MLIR 22.0.0git
BasicPtxBuilderInterface.cpp
Go to the documentation of this file.
//===- BasicPtxBuilderInterface.cpp - PTX builder interface -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Implements the interface used to build PTX (Parallel Thread Execution)
// inline assembly from NVVM ops automatically. It is used by the NVVM-to-LLVM
// conversion pass.
11//
12//===----------------------------------------------------------------------===//
13
16#include "mlir/IR/Diagnostics.h"
17#include "mlir/IR/Location.h"
18#include "mlir/IR/MLIRContext.h"
19
20#include "mlir/Support/LLVM.h"
21#include "llvm/ADT/StringExtras.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/FormatVariadic.h"
25#include "llvm/Support/LogicalResult.h"
26#include "llvm/Support/Regex.h"
27
28#define DEBUG_TYPE "ptx-builder"
29
30//===----------------------------------------------------------------------===//
31// BasicPtxBuilderInterface
32//===----------------------------------------------------------------------===//
33
34#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"
35
36using namespace mlir;
37using namespace NVVM;
38
/// LLVM pointer address space number treated as shared memory; pointers in
/// this address space are given a 32-bit register (see getRegisterType).
static constexpr int64_t kSharedMemorySpace{3};
40
41static FailureOr<char> getRegisterType(Type type, Location loc) {
42 MLIRContext *ctx = type.getContext();
43 auto i16 = IntegerType::get(ctx, 16);
44 auto i32 = IntegerType::get(ctx, 32);
45 auto f32 = Float32Type::get(ctx);
46
47 auto getRegisterTypeForScalar = [&](Type type) -> FailureOr<char> {
48 if (type.isInteger(1))
49 return 'b';
50 if (type.isInteger(16))
51 return 'h';
52 if (type.isInteger(32))
53 return 'r';
54 if (type.isInteger(64))
55 return 'l';
56 if (type.isF32())
57 return 'f';
58 if (type.isF64())
59 return 'd';
60 if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
61 // Shared address spaces is addressed with 32-bit pointers.
62 if (ptr.getAddressSpace() == kSharedMemorySpace) {
63 return 'r';
64 }
65 return 'l';
66 }
67 // register type for struct is not supported.
69 loc, "The register type could not be deduced from MLIR type. The ")
70 << type
71 << " is not supported. Supported types are:"
72 "i1, i16, i32, i64, f32, f64,"
73 "pointers.\nPlease use llvm.bitcast if you have different type. "
74 "\nSee the constraints from here: "
75 "https://docs.nvidia.com/cuda/inline-ptx-assembly/"
76 "index.html#constraints";
77 return failure();
78 };
79
80 // Packed registers
81 if (auto v = dyn_cast<VectorType>(type)) {
82 assert(v.getNumDynamicDims() == 0 && "Dynamic vectors are not supported");
83
84 int64_t lanes = v.getNumElements();
85 Type elem = v.getElementType();
86
87 // Case 1. Single vector
88 if (lanes <= 1)
89 return getRegisterTypeForScalar(elem);
90
91 // Case 2. Packed registers
92 Type widened = elem;
93 switch (lanes) {
94
95 case 2:
96 if (elem.isF16() || elem.isBF16()) // vector<2xf16>
97 widened = f32;
98 else if (elem.isFloat(8)) // vector<2xf8>
99 widened = i16;
100 break;
101 case 4:
102 if (elem.isInteger(8)) // vector<i8x4>
103 widened = i32;
104 else if (elem.isFloat(8)) // vector<f8x4>
105 widened = f32;
106 else if (elem.isFloat(4)) // vector<f4x4>
107 widened = i16;
108 break;
109 // Other packing is not supported
110 default:
111 break;
112 }
113 return getRegisterTypeForScalar(widened);
114 }
115
116 return getRegisterTypeForScalar(type);
117}
118
119static FailureOr<char> getRegisterType(Value v, Location loc) {
120 if (v.getDefiningOp<LLVM::ConstantOp>())
121 return 'n';
122 return getRegisterType(v.getType(), loc);
123}
124
125/// Extract every element of a struct value.
127 Location loc, Value structVal) {
128 auto structTy = dyn_cast<LLVM::LLVMStructType>(structVal.getType());
129 assert(structTy && "expected LLVM struct");
130
131 SmallVector<Value> elems;
132 for (unsigned i : llvm::seq<unsigned>(0, structTy.getBody().size()))
133 elems.push_back(LLVM::ExtractValueOp::create(rewriter, loc, structVal, i));
134
135 return elems;
136}
137
139 LDBG() << v << "\t Modifier : " << itype << "\n";
140 registerModifiers.push_back(itype);
141
142 Location loc = interfaceOp->getLoc();
143 auto getModifier = [&]() -> const char * {
144 switch (itype) {
146 return "";
148 return "=";
150 // "Read-Write modifier is not actually supported
151 // Interface will change it to "=" later and add integer mapping
152 return "+";
153 }
154 llvm_unreachable("Unknown PTX register modifier");
155 };
156
157 auto addValue = [&](Value v) {
158 if (itype == PTXRegisterMod::Read) {
159 ptxOperands.push_back(v);
160 return;
161 }
162 if (itype == PTXRegisterMod::ReadWrite)
163 ptxOperands.push_back(v);
164 hasResult = true;
165 };
166
167 llvm::raw_string_ostream ss(registerConstraints);
168 // Handle Structs
169 if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
170 if (itype == PTXRegisterMod::Write) {
171 addValue(v);
172 }
173 for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
174 if (itype != PTXRegisterMod::Write) {
175 Value extractValue =
176 LLVM::ExtractValueOp::create(rewriter, loc, v, idx);
177 addValue(extractValue);
178 }
179 if (itype == PTXRegisterMod::ReadWrite) {
180 ss << idx << ",";
181 } else {
182 FailureOr<char> regType = getRegisterType(t, loc);
183 if (failed(regType))
184 return rewriter.notifyMatchFailure(loc,
185 "failed to get register type");
186 ss << getModifier() << regType.value() << ",";
187 }
188 }
189 return success();
190 }
191 // Handle Scalars
192 addValue(v);
193 FailureOr<char> regType = getRegisterType(v, loc);
194 if (failed(regType))
195 return rewriter.notifyMatchFailure(loc, "failed to get register type");
196 ss << getModifier() << regType.value() << ",";
197 return success();
198}
199
200/// Check if the operation needs to pack and unpack results.
201static bool
202needsPackUnpack(BasicPtxBuilderInterface interfaceOp,
203 bool needsManualRegisterMapping,
204 SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
205 if (needsManualRegisterMapping)
206 return false;
207 const unsigned writeOnlyVals = interfaceOp->getNumResults();
208 const unsigned readWriteVals =
209 llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
210 return m == PTXRegisterMod::ReadWrite;
211 });
212 return (writeOnlyVals + readWriteVals) > 1;
213}
214
215/// Pack the result types of the interface operation.
216/// If the operation has multiple results, it packs them into a struct
217/// type. Otherwise, it returns the original result types.
218static SmallVector<Type>
219packResultTypes(BasicPtxBuilderInterface interfaceOp,
220 bool needsManualRegisterMapping,
221 SmallVectorImpl<PTXRegisterMod> &registerModifiers,
222 SmallVectorImpl<Value> &ptxOperands) {
223 MLIRContext *ctx = interfaceOp->getContext();
224 TypeRange resultRange = interfaceOp->getResultTypes();
225
226 if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
227 registerModifiers)) {
228 // Single value path:
229 if (interfaceOp->getResults().size() == 1)
230 return SmallVector<Type>{resultRange.front()};
231
232 // No declared results: if there is an RW, forward its type.
233 for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
235 return SmallVector<Type>{v.getType()};
236 }
237
238 SmallVector<Type> packed;
239 for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
241 packed.push_back(v.getType());
242 for (Type t : resultRange)
243 packed.push_back(t);
244
245 if (packed.empty())
246 return {};
247
248 auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false);
249 return SmallVector<Type>{sTy};
250}
251
252/// Canonicalize the register constraints:
253/// - Turn every "+X" into "=X"
254/// - Append (at the very end) the 0-based indices of tokens that were "+X"
255/// Examples:
256/// "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2"
257/// "+f,+f,+r,=r,=r" -> "=f,=f,=r,=r,=r,0,1,2"
258static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
261 SmallVector<unsigned> plusIdx;
262
263 csv.split(toks, ',');
264 out.reserve(toks.size() + 8);
265
266 for (unsigned i = 0, e = toks.size(); i < e; ++i) {
267 StringRef t = toks[i].trim();
268 if (t.consume_front("+")) {
269 plusIdx.push_back(i);
270 out.push_back(("=" + t).str());
271 } else {
272 out.push_back(t.str());
273 }
274 }
275
276 // Append indices of original "+X" tokens.
277 for (unsigned idx : plusIdx)
278 out.push_back(std::to_string(idx));
279
280 // Join back to CSV.
281 std::string result;
282 result.reserve(csv.size() + plusIdx.size() * 2);
283 llvm::raw_string_ostream os(result);
284 for (size_t i = 0; i < out.size(); ++i) {
285 if (i)
286 os << ',';
287 os << out[i];
288 }
289 return os.str();
290}
291
292constexpr llvm::StringLiteral kReadWritePrefix{"rw"};
293constexpr llvm::StringLiteral kWriteOnlyPrefix{"w"};
294constexpr llvm::StringLiteral kReadOnlyPrefix{"r"};
295
296/// Returns a regex that matches {$rwN}, {$wN}, {$rN}
297static llvm::Regex getPredicateMappingRegex() {
298 llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})",
301 .str());
302 return rx;
303}
304
306 StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW,
307 llvm::SmallDenseSet<unsigned int> &seenW,
308 llvm::SmallDenseSet<unsigned int> &seenR,
309 llvm::SmallVectorImpl<unsigned int> &rwNums,
310 llvm::SmallVectorImpl<unsigned int> &wNums,
311 llvm::SmallVectorImpl<unsigned int> &rNums) {
312
313 llvm::Regex rx = getPredicateMappingRegex();
314 StringRef rest = ptxCode;
315
316 SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
317 while (!rest.empty() && rx.match(rest, &m)) {
318 unsigned num = 0;
319 (void)m[2].getAsInteger(10, num);
320 // Insert it into the vector only the first time we see this number
321 if (m[1].equals_insensitive(kReadWritePrefix)) {
322 if (seenRW.insert(num).second)
323 rwNums.push_back(num);
324 } else if (m[1].equals_insensitive(kWriteOnlyPrefix)) {
325 if (seenW.insert(num).second)
326 wNums.push_back(num);
327 } else {
328 if (seenR.insert(num).second)
329 rNums.push_back(num);
330 }
331
332 const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
333 rest = rest.drop_front(advance);
334 }
335}
336
337/// Rewrites `{$rwN}`, `{$wN}`, and `{$rN}` placeholders in `ptxCode` into
338/// compact `$K` indices:
339/// - All `rw*` first (sorted by N),
340/// - Then `w*`,
341/// - Then `r*`.
342/// If there a predicate, it comes always in the end.
343/// Each number is assigned once; duplicates are ignored.
344///
345/// Example Input:
346/// "{
347/// reg .pred p;
348/// setp.ge.s32 p, {$r0}, {$r1};"
349/// selp.s32 {$rw0}, {$r0}, {$r1}, p;
350/// selp.s32 {$rw1}, {$r0}, {$r1}, p;
351/// selp.s32 {$w0}, {$r0}, {$r1}, p;
352/// selp.s32 {$w1}, {$r0}, {$r1}, p;
353/// }\n"
354/// Example Output:
355/// "{
356/// reg .pred p;
357/// setp.ge.s32 p, $4, $5;"
358/// selp.s32 $0, $4, $5, p;
359/// selp.s32 $1, $4, $5, p;
360/// selp.s32 $2, $4, $5, p;
361/// selp.s32 $3, $4, $5, p;
362/// }\n"
363static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) {
364 llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
365 llvm::SmallVector<unsigned> rwNums, wNums, rNums;
366
367 // Step 1. Count Register Placeholder numbers
368 countPlaceholderNumbers(ptxCode, seenRW, seenW, seenR, rwNums, wNums, rNums);
369
370 // Step 2. Sort the Register Placeholder numbers
371 llvm::sort(rwNums);
372 llvm::sort(wNums);
373 llvm::sort(rNums);
374
375 // Step 3. Create mapping from original to new IDs
376 llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap;
377 unsigned nextId = 0;
378 for (unsigned n : rwNums)
379 rwMap[n] = nextId++;
380 for (unsigned n : wNums)
381 wMap[n] = nextId++;
382 for (unsigned n : rNums)
383 rMap[n] = nextId++;
384
385 // Step 4. Rewrite the PTX code with new IDs
386 std::string out;
387 out.reserve(ptxCode.size());
388 size_t prev = 0;
389 StringRef rest = ptxCode;
391 llvm::Regex rx = getPredicateMappingRegex();
392 while (!rest.empty() && rx.match(rest, &matches)) {
393 // Compute absolute match bounds in the original buffer.
394 size_t absStart = (size_t)(matches[0].data() - ptxCode.data());
395 size_t absEnd = absStart + matches[0].size();
396
397 // Emit text before the match.
398 out.append(ptxCode.data() + prev, ptxCode.data() + absStart);
399
400 // Emit compact $K
401 unsigned num = 0;
402 (void)matches[2].getAsInteger(10, num);
403 unsigned id = 0;
404 if (matches[1].equals_insensitive(kReadWritePrefix))
405 id = rwMap.lookup(num);
406 else if (matches[1].equals_insensitive(kWriteOnlyPrefix))
407 id = wMap.lookup(num);
408 else
409 id = rMap.lookup(num);
410
411 out.push_back('$');
412 out += std::to_string(id);
413
414 prev = absEnd;
415
416 const size_t advance =
417 (size_t)(matches[0].data() - rest.data()) + matches[0].size();
418 rest = rest.drop_front(advance);
419 }
420
421 // Step 5. Tail.
422 out.append(ptxCode.data() + prev, ptxCode.data() + ptxCode.size());
423 return out;
424}
425
426LLVM::InlineAsmOp PtxBuilder::build() {
427 auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
428 LLVM::AsmDialect::AD_ATT);
429
431 interfaceOp, needsManualRegisterMapping, registerModifiers, ptxOperands);
432
433 // Remove the last comma from the constraints string.
434 if (!registerConstraints.empty() &&
435 registerConstraints[registerConstraints.size() - 1] == ',')
436 registerConstraints.pop_back();
437 registerConstraints = canonicalizeRegisterConstraints(registerConstraints);
438
439 std::string ptxInstruction = interfaceOp.getPtx();
440 if (!needsManualRegisterMapping)
441 ptxInstruction = rewriteAsmPlaceholders(ptxInstruction);
442
443 // Add the predicate to the asm string.
444 if (interfaceOp.getPredicate().has_value() &&
445 interfaceOp.getPredicate().value()) {
446 std::string predicateStr = "@%";
447 predicateStr += std::to_string((ptxOperands.size() - 1));
448 ptxInstruction = predicateStr + " " + ptxInstruction;
449 }
450
451 // Tablegen doesn't accept $, so we use %, but inline assembly uses $.
452 // Replace all % with $
453 llvm::replace(ptxInstruction, '%', '$');
454
455 return LLVM::InlineAsmOp::create(
456 rewriter, interfaceOp->getLoc(),
457 /*result types=*/resultTypes,
458 /*operands=*/ptxOperands,
459 /*asm_string=*/ptxInstruction,
460 /*constraints=*/registerConstraints.data(),
461 /*has_side_effects=*/interfaceOp.hasSideEffect(),
462 /*is_align_stack=*/false, LLVM::TailCallKind::None,
463 /*asm_dialect=*/asmDialectAttr,
464 /*operand_attrs=*/ArrayAttr());
465}
466
468 LLVM::InlineAsmOp inlineAsmOp = build();
469 LDBG() << "\n Generated PTX \n\t" << inlineAsmOp;
470
471 // Case 0: no result at all → just erase wrapper op.
472 if (!hasResult) {
473 rewriter.eraseOp(interfaceOp);
474 return;
475 }
476
477 if (needsManualRegisterMapping) {
478 rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
479 return;
480 }
481
482 // Case 1: Simple path, return single scalar
483 if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
484 registerModifiers)) {
485 if (inlineAsmOp->getNumResults() > 0) {
486 rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
487 } else {
488 // RW-only case with no declared results: forward the RW value.
489 SmallVector<Value> results;
490 for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
491 if (m == PTXRegisterMod::ReadWrite) {
492 results.push_back(v);
493 break;
494 }
495 rewriter.replaceOp(interfaceOp, results);
496 }
497 return;
498 }
499
500 const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) {
501 return m == PTXRegisterMod::ReadWrite;
502 });
503
504 // All multi-value paths produce a single struct result we need to unpack.
505 assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) &&
506 "expected struct return for multi-result inline asm");
507 Value structVal = inlineAsmOp.getResult(0);
508 SmallVector<Value> unpacked =
509 extractStructElements(rewriter, interfaceOp->getLoc(), structVal);
510
511 // Case 2: only declared results (no RW): replace the op with all unpacked.
512 if (!hasRW && interfaceOp->getResults().size() > 0) {
513 rewriter.replaceOp(interfaceOp, unpacked);
514 return;
515 }
516
517 // Case 3: RW-only (no declared results): update RW uses and erase wrapper.
518 if (hasRW && interfaceOp->getResults().size() == 0) {
519 unsigned idx = 0;
520 for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
522 continue;
523 Value repl = unpacked[idx++];
524 v.replaceUsesWithIf(repl, [&](OpOperand &use) {
525 Operation *owner = use.getOwner();
526 return owner != interfaceOp && owner != inlineAsmOp;
527 });
528 }
529 rewriter.eraseOp(interfaceOp);
530 return;
531 }
532
533 // Case 4: mixed (RW + declared results).
534 {
535 // First rewrite RW operands in place.
536 unsigned idx = 0;
537 for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
539 continue;
540 Value repl = unpacked[idx++];
541 v.replaceUsesWithIf(repl, [&](OpOperand &use) {
542 Operation *owner = use.getOwner();
543 return owner != interfaceOp && owner != inlineAsmOp;
544 });
545 }
546 // The remaining unpacked values correspond to the declared results.
548 tail.reserve(unpacked.size() - idx);
549 for (unsigned i = idx, e = unpacked.size(); i < e; ++i)
550 tail.push_back(unpacked[i]);
551
552 rewriter.replaceOp(interfaceOp, tail);
553 }
554}
return success()
static std::string canonicalizeRegisterConstraints(llvm::StringRef csv)
Canonicalize the register constraints:
constexpr llvm::StringLiteral kWriteOnlyPrefix
static SmallVector< Value > extractStructElements(PatternRewriter &rewriter, Location loc, Value structVal)
Extract every element of a struct value.
static FailureOr< char > getRegisterType(Type type, Location loc)
static constexpr int64_t kSharedMemorySpace
static SmallVector< Type > packResultTypes(BasicPtxBuilderInterface interfaceOp, bool needsManualRegisterMapping, SmallVectorImpl< PTXRegisterMod > &registerModifiers, SmallVectorImpl< Value > &ptxOperands)
Pack the result types of the interface operation.
static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode)
Rewrites {$rwN}, {$wN}, and {$rN} placeholders in ptxCode into compact $K indices:
static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp, bool needsManualRegisterMapping, SmallVectorImpl< PTXRegisterMod > &registerModifiers)
Check if the operation needs to pack and unpack results.
constexpr llvm::StringLiteral kReadWritePrefix
constexpr llvm::StringLiteral kReadOnlyPrefix
static llvm::Regex getPredicateMappingRegex()
Returns a regex that matches {$rwN}, {$wN}, {$rN}.
ArrayAttr()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
LogicalResult insertValue(Value v, PTXRegisterMod itype=PTXRegisterMod::Read)
Add an operand with the read/write input type.
LLVM::InlineAsmOp build()
Builds the inline assembly Op and returns it.
void buildAndReplaceOp()
Shortcut to build the inline assembly Op and replace or erase the original op with.
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isFloat() const
Return true if this is an float type (with the specified width).
Definition Types.cpp:45
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isF16() const
Definition Types.cpp:38
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
PTXRegisterMod
Register read/write modifier to build constraint string for PTX inline https://docs....
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
void countPlaceholderNumbers(StringRef ptxCode, llvm::SmallDenseSet< unsigned > &seenRW, llvm::SmallDenseSet< unsigned > &seenW, llvm::SmallDenseSet< unsigned > &seenR, llvm::SmallVectorImpl< unsigned > &rwNums, llvm::SmallVectorImpl< unsigned > &wNums, llvm::SmallVectorImpl< unsigned > &rNums)
Count the number of placeholder variables such as {$r}, {$w}, {$rw} in the PTX code.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.