21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "llvm/Support/LogicalResult.h"
26 #include "llvm/Support/Regex.h"
28 #define DEBUG_TYPE "ptx-builder"
34 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"
47 auto getRegisterTypeForScalar = [&](
Type type) -> FailureOr<char> {
60 if (
auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
69 loc,
"The register type could not be deduced from MLIR type. The ")
71 <<
" is not supported. Supported types are:"
72 "i1, i16, i32, i64, f32, f64,"
73 "pointers.\nPlease use llvm.bitcast if you have different type. "
74 "\nSee the constraints from here: "
75 "https://docs.nvidia.com/cuda/inline-ptx-assembly/"
76 "index.html#constraints";
81 if (
auto v = dyn_cast<VectorType>(type)) {
82 assert(v.getNumDynamicDims() == 0 &&
"Dynamic vectors are not supported");
84 int64_t lanes = v.getNumElements();
85 Type elem = v.getElementType();
89 return getRegisterTypeForScalar(elem);
113 return getRegisterTypeForScalar(widened);
116 return getRegisterTypeForScalar(type);
128 auto structTy = dyn_cast<LLVM::LLVMStructType>(structVal.
getType());
129 assert(structTy &&
"expected LLVM struct");
132 for (
unsigned i : llvm::seq<unsigned>(0, structTy.getBody().size()))
133 elems.push_back(LLVM::ExtractValueOp::create(rewriter, loc, structVal, i));
139 LDBG() << v <<
"\t Modifier : " << itype <<
"\n";
140 registerModifiers.push_back(itype);
142 Location loc = interfaceOp->getLoc();
143 auto getModifier = [&]() ->
const char * {
154 llvm_unreachable(
"Unknown PTX register modifier");
157 auto addValue = [&](
Value v) {
159 ptxOperands.push_back(v);
163 ptxOperands.push_back(v);
167 llvm::raw_string_ostream ss(registerConstraints);
169 if (
auto stype = dyn_cast<LLVM::LLVMStructType>(v.
getType())) {
176 LLVM::ExtractValueOp::create(rewriter, loc, v, idx);
177 addValue(extractValue);
185 "failed to get register type");
186 ss << getModifier() << regType.value() <<
",";
196 ss << getModifier() << regType.value() <<
",";
203 bool needsManualRegisterMapping,
205 if (needsManualRegisterMapping)
207 const unsigned writeOnlyVals = interfaceOp->getNumResults();
208 const unsigned readWriteVals =
212 return (writeOnlyVals + readWriteVals) > 1;
220 bool needsManualRegisterMapping,
224 TypeRange resultRange = interfaceOp->getResultTypes();
227 registerModifiers)) {
229 if (interfaceOp->getResults().size() == 1)
233 for (
auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
239 for (
auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
241 packed.push_back(v.getType());
242 for (
Type t : resultRange)
248 auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed,
false);
263 csv.split(toks,
',');
264 out.reserve(toks.size() + 8);
266 for (
unsigned i = 0, e = toks.size(); i < e; ++i) {
267 StringRef t = toks[i].trim();
268 if (t.consume_front(
"+")) {
269 plusIdx.push_back(i);
270 out.push_back((
"=" + t).str());
272 out.push_back(t.str());
277 for (
unsigned idx : plusIdx)
278 out.push_back(std::to_string(idx));
282 result.reserve(csv.size() + plusIdx.size() * 2);
283 llvm::raw_string_ostream os(result);
284 for (
size_t i = 0; i < out.size(); ++i) {
298 llvm::Regex rx(llvm::formatv(R
"(\{\$({0}|{1}|{2})([0-9]+)\})",
306 StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW,
307 llvm::SmallDenseSet<unsigned int> &seenW,
308 llvm::SmallDenseSet<unsigned int> &seenR,
314 StringRef rest = ptxCode;
317 while (!rest.empty() && rx.match(rest, &m)) {
319 (void)m[2].getAsInteger(10, num);
322 if (seenRW.insert(num).second)
323 rwNums.push_back(num);
325 if (seenW.insert(num).second)
326 wNums.push_back(num);
328 if (seenR.insert(num).second)
329 rNums.push_back(num);
332 const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
333 rest = rest.drop_front(advance);
364 llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
378 for (
unsigned n : rwNums)
380 for (
unsigned n : wNums)
382 for (
unsigned n : rNums)
387 out.reserve(ptxCode.size());
389 StringRef rest = ptxCode;
392 while (!rest.empty() && rx.match(rest, &matches)) {
394 size_t absStart = (size_t)(matches[0].data() - ptxCode.data());
395 size_t absEnd = absStart + matches[0].size();
398 out.append(ptxCode.data() + prev, ptxCode.data() + absStart);
402 (void)matches[2].getAsInteger(10, num);
405 id = rwMap.lookup(num);
407 id = wMap.lookup(num);
409 id = rMap.lookup(num);
412 out += std::to_string(
id);
416 const size_t advance =
417 (size_t)(matches[0].data() - rest.data()) + matches[0].size();
418 rest = rest.drop_front(advance);
422 out.append(ptxCode.data() + prev, ptxCode.data() + ptxCode.size());
428 LLVM::AsmDialect::AD_ATT);
431 interfaceOp, needsManualRegisterMapping, registerModifiers, ptxOperands);
434 if (!registerConstraints.empty() &&
435 registerConstraints[registerConstraints.size() - 1] ==
',')
436 registerConstraints.pop_back();
439 std::string ptxInstruction = interfaceOp.getPtx();
440 if (!needsManualRegisterMapping)
444 if (interfaceOp.getPredicate().has_value() &&
445 interfaceOp.getPredicate().value()) {
446 std::string predicateStr =
"@%";
447 predicateStr += std::to_string((ptxOperands.size() - 1));
448 ptxInstruction = predicateStr +
" " + ptxInstruction;
453 llvm::replace(ptxInstruction,
'%',
'$');
455 return LLVM::InlineAsmOp::create(
456 rewriter, interfaceOp->getLoc(),
460 registerConstraints.data(),
461 interfaceOp.hasSideEffect(),
468 LLVM::InlineAsmOp inlineAsmOp =
build();
469 LDBG() <<
"\n Generated PTX \n\t" << inlineAsmOp;
477 if (needsManualRegisterMapping) {
478 rewriter.
replaceOp(interfaceOp, inlineAsmOp->getResults());
484 registerModifiers)) {
485 if (inlineAsmOp->getNumResults() > 0) {
486 rewriter.
replaceOp(interfaceOp, inlineAsmOp->getResults());
490 for (
auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
492 results.push_back(v);
495 rewriter.
replaceOp(interfaceOp, results);
500 const bool hasRW = llvm::any_of(registerModifiers, [](
PTXRegisterMod m) {
505 assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) &&
506 "expected struct return for multi-result inline asm");
507 Value structVal = inlineAsmOp.getResult(0);
512 if (!hasRW && interfaceOp->getResults().size() > 0) {
513 rewriter.
replaceOp(interfaceOp, unpacked);
518 if (hasRW && interfaceOp->getResults().size() == 0) {
520 for (
auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
523 Value repl = unpacked[idx++];
524 v.replaceUsesWithIf(repl, [&](
OpOperand &use) {
526 return owner != interfaceOp && owner != inlineAsmOp;
537 for (
auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
540 Value repl = unpacked[idx++];
541 v.replaceUsesWithIf(repl, [&](
OpOperand &use) {
543 return owner != interfaceOp && owner != inlineAsmOp;
548 tail.reserve(unpacked.size() - idx);
549 for (
unsigned i = idx, e = unpacked.size(); i < e; ++i)
550 tail.push_back(unpacked[i]);
static std::string canonicalizeRegisterConstraints(llvm::StringRef csv)
Canonicalize the register constraints:
constexpr llvm::StringLiteral kWriteOnlyPrefix
static FailureOr< char > getRegisterType(Type type, Location loc)
static constexpr int64_t kSharedMemorySpace
static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode)
Rewrites {$rwN}, {$wN}, and {$rN} placeholders in ptxCode into compact $K indices:
static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp, bool needsManualRegisterMapping, SmallVectorImpl< PTXRegisterMod > ®isterModifiers)
Check if the operation needs to pack and unpack results.
static SmallVector< Type > packResultTypes(BasicPtxBuilderInterface interfaceOp, bool needsManualRegisterMapping, SmallVectorImpl< PTXRegisterMod > ®isterModifiers, SmallVectorImpl< Value > &ptxOperands)
Pack the result types of the interface operation.
constexpr llvm::StringLiteral kReadWritePrefix
constexpr llvm::StringLiteral kReadOnlyPrefix
static llvm::Regex getPredicateMappingRegex()
Returns a regex that matches {$rwN}, {$wN}, {$rN}.
static SmallVector< Value > extractStructElements(PatternRewriter &rewriter, Location loc, Value structVal)
Extract every element of a struct value.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
LogicalResult insertValue(Value v, PTXRegisterMod itype=PTXRegisterMod::Read)
Add an operand with the read/write input type.
LLVM::InlineAsmOp build()
Builds the inline assembly Op and returns it.
void buildAndReplaceOp()
Shortcut to build the inline assembly Op and replace or erase the original op with.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat() const
Return true if this is an float type (with the specified width).
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
PTXRegisterMod
Register read/write modifier to build constraint string for PTX inline https://docs....
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
void countPlaceholderNumbers(StringRef ptxCode, llvm::SmallDenseSet< unsigned > &seenRW, llvm::SmallDenseSet< unsigned > &seenW, llvm::SmallDenseSet< unsigned > &seenR, llvm::SmallVectorImpl< unsigned > &rwNums, llvm::SmallVectorImpl< unsigned > &wNums, llvm::SmallVectorImpl< unsigned > &rNums)
Count the number of placeholder variables such as {$r}, {$w}, {$rw} in the PTX code.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...