21 #include "llvm/ADT/StringExtras.h" 
   22 #include "llvm/ADT/TypeSwitch.h" 
   23 #include "llvm/Support/DebugLog.h" 
   24 #include "llvm/Support/FormatVariadic.h" 
   25 #include "llvm/Support/LogicalResult.h" 
   26 #include "llvm/Support/Regex.h" 
   28 #define DEBUG_TYPE "ptx-builder" 
   34 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc" 
   47   auto getRegisterTypeForScalar = [&](
Type type) -> FailureOr<char> {
 
   60     if (
auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
 
   69         loc, 
"The register type could not be deduced from MLIR type. The ")
 
   71         << 
" is not supported. Supported types are:" 
   72            "i1, i16, i32, i64, f32, f64," 
   73            "pointers.\nPlease use llvm.bitcast if you have different type. " 
   74            "\nSee the constraints from here: " 
   75            "https://docs.nvidia.com/cuda/inline-ptx-assembly/" 
   76            "index.html#constraints";
 
   81   if (
auto v = dyn_cast<VectorType>(type)) {
 
   82     assert(v.getNumDynamicDims() == 0 && 
"Dynamic vectors are not supported");
 
   84     int64_t lanes = v.getNumElements();
 
   85     Type elem = v.getElementType();
 
   89       return getRegisterTypeForScalar(elem);
 
  113     return getRegisterTypeForScalar(widened);
 
  116   return getRegisterTypeForScalar(type);
 
  128   auto structTy = dyn_cast<LLVM::LLVMStructType>(structVal.
getType());
 
  129   assert(structTy && 
"expected LLVM struct");
 
  132   for (
unsigned i : llvm::seq<unsigned>(0, structTy.getBody().size()))
 
  133     elems.push_back(LLVM::ExtractValueOp::create(rewriter, loc, structVal, i));
 
  139   LDBG() << v << 
"\t Modifier : " << itype << 
"\n";
 
  140   registerModifiers.push_back(itype);
 
  142   Location loc = interfaceOp->getLoc();
 
  143   auto getModifier = [&]() -> 
const char * {
 
  154     llvm_unreachable(
"Unknown PTX register modifier");
 
  157   auto addValue = [&](
Value v) {
 
  159       ptxOperands.push_back(v);
 
  163       ptxOperands.push_back(v);
 
  167   llvm::raw_string_ostream ss(registerConstraints);
 
  169   if (
auto stype = dyn_cast<LLVM::LLVMStructType>(v.
getType())) {
 
  176             LLVM::ExtractValueOp::create(rewriter, loc, v, idx);
 
  177         addValue(extractValue);
 
  185                                              "failed to get register type");
 
  186         ss << getModifier() << regType.value() << 
",";
 
  196   ss << getModifier() << regType.value() << 
",";
 
  203                 bool needsManualRegisterMapping,
 
  205   if (needsManualRegisterMapping)
 
  207   const unsigned writeOnlyVals = interfaceOp->getNumResults();
 
  208   const unsigned readWriteVals =
 
  212   return (writeOnlyVals + readWriteVals) > 1;
 
  220                 bool needsManualRegisterMapping,
 
  224   TypeRange resultRange = interfaceOp->getResultTypes();
 
  227                        registerModifiers)) {
 
  229     if (interfaceOp->getResults().size() == 1)
 
  233     for (
auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
 
  239   for (
auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
 
  241       packed.push_back(v.getType());
 
  242   for (
Type t : resultRange)
 
  248   auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, 
false);
 
  263   csv.split(toks, 
',');
 
  264   out.reserve(toks.size() + 8);
 
  266   for (
unsigned i = 0, e = toks.size(); i < e; ++i) {
 
  267     StringRef t = toks[i].trim();
 
  268     if (t.consume_front(
"+")) {
 
  269       plusIdx.push_back(i);
 
  270       out.push_back((
"=" + t).str());
 
  272       out.push_back(t.str());
 
  277   for (
unsigned idx : plusIdx)
 
  278     out.push_back(std::to_string(idx));
 
  282   result.reserve(csv.size() + plusIdx.size() * 2);
 
  283   llvm::raw_string_ostream os(result);
 
  284   for (
size_t i = 0; i < out.size(); ++i) {
 
  298   llvm::Regex rx(llvm::formatv(R
"(\{\$({0}|{1}|{2})([0-9]+)\})", 
  306     StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW,
 
  307     llvm::SmallDenseSet<unsigned int> &seenW,
 
  308     llvm::SmallDenseSet<unsigned int> &seenR,
 
  314   StringRef rest = ptxCode;
 
  317   while (!rest.empty() && rx.match(rest, &m)) {
 
  319     (void)m[2].getAsInteger(10, num);
 
  322       if (seenRW.insert(num).second)
 
  323         rwNums.push_back(num);
 
  325       if (seenW.insert(num).second)
 
  326         wNums.push_back(num);
 
  328       if (seenR.insert(num).second)
 
  329         rNums.push_back(num);
 
  332     const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
 
  333     rest = rest.drop_front(advance);
 
  364   llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
 
  378   for (
unsigned n : rwNums)
 
  380   for (
unsigned n : wNums)
 
  382   for (
unsigned n : rNums)
 
  387   out.reserve(ptxCode.size());
 
  389   StringRef rest = ptxCode;
 
  392   while (!rest.empty() && rx.match(rest, &matches)) {
 
  394     size_t absStart = (size_t)(matches[0].data() - ptxCode.data());
 
  395     size_t absEnd = absStart + matches[0].size();
 
  398     out.append(ptxCode.data() + prev, ptxCode.data() + absStart);
 
  402     (void)matches[2].getAsInteger(10, num);
 
  405       id = rwMap.lookup(num);
 
  407       id = wMap.lookup(num);
 
  409       id = rMap.lookup(num);
 
  412     out += std::to_string(
id);
 
  416     const size_t advance =
 
  417         (size_t)(matches[0].data() - rest.data()) + matches[0].size();
 
  418     rest = rest.drop_front(advance);
 
  422   out.append(ptxCode.data() + prev, ptxCode.data() + ptxCode.size());
 
  428                                                   LLVM::AsmDialect::AD_ATT);
 
  431       interfaceOp, needsManualRegisterMapping, registerModifiers, ptxOperands);
 
  434   if (!registerConstraints.empty() &&
 
  435       registerConstraints[registerConstraints.size() - 1] == 
',')
 
  436     registerConstraints.pop_back();
 
  439   std::string ptxInstruction = interfaceOp.getPtx();
 
  440   if (!needsManualRegisterMapping)
 
  444   if (interfaceOp.getPredicate().has_value() &&
 
  445       interfaceOp.getPredicate().value()) {
 
  446     std::string predicateStr = 
"@%";
 
  447     predicateStr += std::to_string((ptxOperands.size() - 1));
 
  448     ptxInstruction = predicateStr + 
" " + ptxInstruction;
 
  453   llvm::replace(ptxInstruction, 
'%', 
'$');
 
  455   return LLVM::InlineAsmOp::create(
 
  456       rewriter, interfaceOp->getLoc(),
 
  460       registerConstraints.data(),
 
  461       interfaceOp.hasSideEffect(),
 
  468   LLVM::InlineAsmOp inlineAsmOp = 
build();
 
  469   LDBG() << 
"\n Generated PTX \n\t" << inlineAsmOp;
 
  477   if (needsManualRegisterMapping) {
 
  478     rewriter.
replaceOp(interfaceOp, inlineAsmOp->getResults());
 
  484                        registerModifiers)) {
 
  485     if (inlineAsmOp->getNumResults() > 0) {
 
  486       rewriter.
replaceOp(interfaceOp, inlineAsmOp->getResults());
 
  490       for (
auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
 
  492           results.push_back(v);
 
  495       rewriter.
replaceOp(interfaceOp, results);
 
  500   const bool hasRW = llvm::any_of(registerModifiers, [](
PTXRegisterMod m) {
 
  505   assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) &&
 
  506          "expected struct return for multi-result inline asm");
 
  507   Value structVal = inlineAsmOp.getResult(0);
 
  512   if (!hasRW && interfaceOp->getResults().size() > 0) {
 
  513     rewriter.
replaceOp(interfaceOp, unpacked);
 
  518   if (hasRW && interfaceOp->getResults().size() == 0) {
 
  520     for (
auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
 
  523       Value repl = unpacked[idx++];
 
  524       v.replaceUsesWithIf(repl, [&](
OpOperand &use) {
 
  526         return owner != interfaceOp && owner != inlineAsmOp;
 
  537     for (
auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
 
  540       Value repl = unpacked[idx++];
 
  541       v.replaceUsesWithIf(repl, [&](
OpOperand &use) {
 
  543         return owner != interfaceOp && owner != inlineAsmOp;
 
  548     tail.reserve(unpacked.size() - idx);
 
  549     for (
unsigned i = idx, e = unpacked.size(); i < e; ++i)
 
  550       tail.push_back(unpacked[i]);
 
static std::string canonicalizeRegisterConstraints(llvm::StringRef csv)
Canonicalize the register constraints:
constexpr llvm::StringLiteral kWriteOnlyPrefix
static FailureOr< char > getRegisterType(Type type, Location loc)
static constexpr int64_t kSharedMemorySpace
static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode)
Rewrites {$rwN}, {$wN}, and {$rN} placeholders in ptxCode into compact $K indices:
static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp, bool needsManualRegisterMapping, SmallVectorImpl< PTXRegisterMod > &registerModifiers)
Check if the operation needs to pack and unpack results.
static SmallVector< Type > packResultTypes(BasicPtxBuilderInterface interfaceOp, bool needsManualRegisterMapping, SmallVectorImpl< PTXRegisterMod > &registerModifiers, SmallVectorImpl< Value > &ptxOperands)
Pack the result types of the interface operation.
constexpr llvm::StringLiteral kReadWritePrefix
constexpr llvm::StringLiteral kReadOnlyPrefix
static llvm::Regex getPredicateMappingRegex()
Returns a regex that matches {$rwN}, {$wN}, {$rN}.
static SmallVector< Value > extractStructElements(PatternRewriter &rewriter, Location loc, Value structVal)
Extract every element of a struct value.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
LogicalResult insertValue(Value v, PTXRegisterMod itype=PTXRegisterMod::Read)
Add an operand with the read/write input type.
LLVM::InlineAsmOp build()
Builds the inline assembly Op and returns it.
void buildAndReplaceOp()
Shortcut to build the inline assembly Op and replace or erase the original op with.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat() const
Return true if this is an float type (with the specified width).
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
PTXRegisterMod
Register read/write modifier to build constraint string for PTX inline https://docs....
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
void countPlaceholderNumbers(StringRef ptxCode, llvm::SmallDenseSet< unsigned > &seenRW, llvm::SmallDenseSet< unsigned > &seenW, llvm::SmallDenseSet< unsigned > &seenR, llvm::SmallVectorImpl< unsigned > &rwNums, llvm::SmallVectorImpl< unsigned > &wNums, llvm::SmallVectorImpl< unsigned > &rNums)
Count the number of placeholder variables such as {$r}, {$w}, {$rw} in the PTX code.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...