MLIR  17.0.0git
MemRefToLLVM.cpp
Go to the documentation of this file.
1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/IRMapping.h"
24 #include "mlir/Pass/Pass.h"
26 #include "llvm/ADT/SmallBitVector.h"
27 #include <optional>
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 using namespace mlir;
35 
36 namespace {
37 
38 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
39  return !ShapedType::isDynamic(strideOrOffset);
40 }
41 
/// Returns the deallocation function to call when lowering a buffer free:
/// either the generic conversion-library wrapper or plain `free`, depending on
/// the `useGenericFunctions` conversion option.
LLVM::LLVMFuncOp getFreeFn(LLVMTypeConverter *typeConverter, ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    // NOTE(review): the callee of this call (presumably
    // LLVM::lookupOrCreateGenericFreeFn) is missing from this copy of the
    // file — the line appears to have been lost when the file was extracted.
        module, typeConverter->useOpaquePointers());

  return LLVM::lookupOrCreateFreeFn(module, typeConverter->useOpaquePointers());
}
51 
52 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
53  AllocOpLowering(LLVMTypeConverter &converter)
54  : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
55  converter) {}
56  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
57  Location loc, Value sizeBytes,
58  Operation *op) const override {
59  return allocateBufferManuallyAlign(
60  rewriter, loc, sizeBytes, op,
61  getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
62  }
63 };
64 
65 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
66  AlignedAllocOpLowering(LLVMTypeConverter &converter)
67  : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
68  converter) {}
69  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
70  Location loc, Value sizeBytes,
71  Operation *op) const override {
72  Value ptr = allocateBufferAutoAlign(
73  rewriter, loc, sizeBytes, op, &defaultLayout,
74  alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
75  &defaultLayout));
76  return std::make_tuple(ptr, ptr);
77  }
78 
79 private:
80  /// Default layout to use in absence of the corresponding analysis.
81  DataLayout defaultLayout;
82 };
83 
84 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
85  AllocaOpLowering(LLVMTypeConverter &converter)
86  : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
87  converter) {}
88 
89  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
90  /// is set to null for stack allocations. `accessAlignment` is set if
91  /// alignment is needed post allocation (for eg. in conjunction with malloc).
92  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
93  Location loc, Value sizeBytes,
94  Operation *op) const override {
95 
96  // With alloca, one gets a pointer to the element type right away.
97  // For stack allocations.
98  auto allocaOp = cast<memref::AllocaOp>(op);
99  auto elementType =
100  typeConverter->convertType(allocaOp.getType().getElementType());
101  unsigned addrSpace =
102  *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
103  auto elementPtrType =
104  getTypeConverter()->getPointerType(elementType, addrSpace);
105 
106  auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
107  loc, elementPtrType, elementType, sizeBytes,
108  allocaOp.getAlignment().value_or(0));
109 
110  return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
111  }
112 };
113 
/// The base class for lowering realloc op, to support the implementation of
/// realloc via allocation methods that may or may not support alignment.
/// A derived class should provide an implementation of allocateBuffer using
/// the underlying allocation methods.
struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
  using OpAdaptor = typename memref::ReallocOp::Adaptor;

  ReallocOpLoweringBase(LLVMTypeConverter &converter)
      : AllocationOpLLVMLowering(memref::ReallocOp::getOperationName(),
                                 converter) {}

  /// Allocates the new buffer. Returns the allocated pointer and the
  /// aligned pointer.
  virtual std::tuple<Value, Value>
  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
                 Value sizeBytes, memref::ReallocOp op) const = 0;

  // Type-erased hook: forwards to the typed matchAndRewrite below.
  // NOTE(review): the `LogicalResult` return type appears to be missing from
  // this copy of the file (extraction artifact).
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    return matchAndRewrite(cast<memref::ReallocOp>(op),
                           OpAdaptor(operands, op->getAttrDictionary()),
                           rewriter);
  }

  // A `realloc` is converted as follows:
  // If new_size > old_size
  //   1. allocates a new buffer
  //   2. copies the content of the old buffer to the new buffer
  //   3. release the old buffer
  //   4. updates the buffer pointers in the memref descriptor
  // Update the size in the memref descriptor
  // Alignment request is handled by allocating `alignment` more bytes than
  // requested and shifting the aligned pointer relative to the allocated
  // memory.
  LogicalResult matchAndRewrite(memref::ReallocOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = op.getLoc();

    // Computes the element count of the (1-D) memref `type`: a constant for a
    // static shape, otherwise the value produced by `getDynamicSize`,
    // materialized into the index type when needed.
    auto computeNumElements =
        [&](MemRefType type, function_ref<Value()> getDynamicSize) -> Value {
      // Compute number of elements.
      int64_t size = type.getShape()[0];
      Value numElements = ((size == ShapedType::kDynamic)
                               ? getDynamicSize()
                               : createIndexConstant(rewriter, loc, size));
      Type indexType = getIndexType();
      if (numElements.getType() != indexType)
        numElements = typeConverter->materializeTargetConversion(
            rewriter, loc, indexType, numElements);
      return numElements;
    };

    MemRefDescriptor desc(adaptor.getSource());
    Value oldDesc = desc;

    // Split the block right before the current op into two blocks.
    Block *currentBlock = rewriter.getInsertionBlock();
    Block *block =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    // Add a block argument by creating an empty block with the argument type
    // and then merging the block into the empty block.
    Block *endBlock = rewriter.createBlock(
        block->getParent(), Region::iterator(block), oldDesc.getType(), loc);
    rewriter.mergeBlocks(block, endBlock, {});
    // Add a new block for the true branch of the conditional statement we will
    // add.
    Block *trueBlock = rewriter.createBlock(
        currentBlock->getParent(), std::next(Region::iterator(currentBlock)));

    rewriter.setInsertionPointToEnd(currentBlock);
    Value src = op.getSource();
    auto srcType = src.getType().dyn_cast<MemRefType>();
    Value srcNumElements = computeNumElements(
        srcType, [&]() -> Value { return desc.size(rewriter, loc, 0); });
    auto dstType = op.getType().cast<MemRefType>();
    Value dstNumElements = computeNumElements(
        dstType, [&]() -> Value { return op.getDynamicResultSize(); });
    // Only grow: reallocate when the requested (unsigned) element count is
    // strictly larger than the current one, otherwise keep the old buffer.
    Value cond = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::ugt, dstNumElements, srcNumElements);
    rewriter.create<LLVM::CondBrOp>(loc, cond, trueBlock, ArrayRef<Value>(),
                                    endBlock, ValueRange{oldDesc});

    rewriter.setInsertionPointToStart(trueBlock);
    Value sizeInBytes = getSizeInBytes(loc, dstType.getElementType(), rewriter);
    // Compute total byte size.
    auto dstByteSize =
        rewriter.create<LLVM::MulOp>(loc, dstNumElements, sizeInBytes);
    // Since the src and dst memref are guaranteed to have the same
    // element type by the verifier, it is safe here to reuse the
    // type size computed from dst memref.
    auto srcByteSize =
        rewriter.create<LLVM::MulOp>(loc, srcNumElements, sizeInBytes);
    // Allocate a new buffer.
    auto [dstRawPtr, dstAlignedPtr] =
        allocateBuffer(rewriter, loc, dstByteSize, op);
    // Copy the data from the old buffer to the new buffer.
    Value srcAlignedPtr = desc.alignedPtr(rewriter, loc);
    Value isVolatile =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(false));
    // With typed pointers, memcpy/free expect i8*; with opaque pointers no
    // bitcast is necessary.
    auto toVoidPtr = [&](Value ptr) -> Value {
      if (getTypeConverter()->useOpaquePointers())
        return ptr;
      return rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);
    };
    rewriter.create<LLVM::MemcpyOp>(loc, toVoidPtr(dstAlignedPtr),
                                    toVoidPtr(srcAlignedPtr), srcByteSize,
                                    isVolatile);
    // Deallocate the old buffer.
    LLVM::LLVMFuncOp freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    rewriter.create<LLVM::CallOp>(loc, freeFunc,
                                  toVoidPtr(desc.allocatedPtr(rewriter, loc)));
    // Replace the old buffer addresses in the MemRefDescriptor with the new
    // buffer addresses.
    desc.setAllocatedPtr(rewriter, loc, dstRawPtr);
    desc.setAlignedPtr(rewriter, loc, dstAlignedPtr);
    rewriter.create<LLVM::BrOp>(loc, Value(desc), endBlock);

    rewriter.setInsertionPoint(op);
    // Update the memref size.
    MemRefDescriptor newDesc(endBlock->getArgument(0));
    newDesc.setSize(rewriter, loc, 0, dstNumElements);
    rewriter.replaceOp(op, {newDesc});
    return success();
  }

private:
  // NOTE(review): this private section is empty in this copy; a declaration
  // (likely a `using` directive) appears to have been lost in extraction.
};
246 
247 struct ReallocOpLowering : public ReallocOpLoweringBase {
248  ReallocOpLowering(LLVMTypeConverter &converter)
249  : ReallocOpLoweringBase(converter) {}
250  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
251  Location loc, Value sizeBytes,
252  memref::ReallocOp op) const override {
253  return allocateBufferManuallyAlign(rewriter, loc, sizeBytes, op,
254  getAlignment(rewriter, loc, op));
255  }
256 };
257 
258 struct AlignedReallocOpLowering : public ReallocOpLoweringBase {
259  AlignedReallocOpLowering(LLVMTypeConverter &converter)
260  : ReallocOpLoweringBase(converter) {}
261  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
262  Location loc, Value sizeBytes,
263  memref::ReallocOp op) const override {
264  Value ptr = allocateBufferAutoAlign(
265  rewriter, loc, sizeBytes, op, &defaultLayout,
266  alignedAllocationGetAlignment(rewriter, loc, op, &defaultLayout));
267  return std::make_tuple(ptr, ptr);
268  }
269 
270 private:
271  /// Default layout to use in absence of the corresponding analysis.
272  DataLayout defaultLayout;
273 };
274 
/// Lowers memref.alloca_scope by inlining its body between an
/// llvm.intr.stacksave / llvm.intr.stackrestore pair, so stack allocations
/// made inside the scope are released when control leaves it.
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  // NOTE(review): the inherited-constructor `using` declaration and the
  // `LogicalResult` return type below appear to be missing from this copy of
  // the file (extraction artifact).

  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      // Scope results leave the region as block arguments of a dedicated
      // continuation block that branches on to the remaining ops.
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};
330 
/// Lowers memref.assume_alignment to an llvm.intr.assume of
/// `alignedPtr & (alignment - 1) == 0`, then erases the op (it has no
/// results).
struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  // NOTE(review): the first line of this `using` declaration is missing from
  // this copy of the file; only its continuation survives below.
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  // NOTE(review): the `LogicalResult` return type appears to be missing here.
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the same
    // pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    // `alignment - 1` as a low-bit mask — this mask trick presumes a
    // power-of-two alignment.
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};
369 
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  // NOTE(review): an inherited-constructor `using` declaration appears to be
  // missing from this copy of the file (extraction artifact).

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  // NOTE(review): the `LogicalResult` return type appears to be missing here.
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    LLVM::LLVMFuncOp freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      // Unranked case: the allocated pointer must be read through the pointer
      // to the ranked descriptor stored inside the unranked descriptor.
      Type elementType = unrankedTy.getElementType();
      Type llvmElementTy = getTypeConverter()->convertType(elementType);
      LLVM::LLVMPointerType elementPtrTy = getTypeConverter()->getPointerType(
          llvmElementTy, unrankedTy.getMemorySpaceAsInt());
      // NOTE(review): the beginning of the statement below (the assignment to
      // `allocatedPtr` and its callee) is missing from this copy of the file.
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    // With typed pointers, `free` takes an i8* argument, so bitcast first.
    if (!getTypeConverter()->useOpaquePointers())
      allocatedPtr = rewriter.create<LLVM::BitcastOp>(
          op.getLoc(), getVoidPtrType(), allocatedPtr);

    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
    return success();
  }
};
409 
// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  // NOTE(review): the inherited-constructor `using` declaration and the
  // `LogicalResult` return type of matchAndRewrite appear to be missing from
  // this copy of the file (extraction artifact).

  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    // Unranked sources require reading the size out of the in-memory ranked
    // descriptor; ranked sources read it from the SSA descriptor directly.
    if (operandType.isa<UnrankedMemRefType>()) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  // Loads the requested dimension size from the ranked descriptor that the
  // unranked descriptor points to.
  // NOTE(review): the `FailureOr<Value>` return type appears to be missing
  // from this copy of the file.
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);
    Value scalarMemRefDescPtr;
    if (getTypeConverter()->useOpaquePointers())
      scalarMemRefDescPtr = underlyingRankedDesc;
    else
      scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
          loc, LLVM::LLVMPointerType::get(elementType, addressSpace),
          underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = getTypeConverter()->getPointerType(
        getTypeConverter()->getIndexType(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, scalarMemRefDescPtr,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPop with
    // `dimOp.index() + 1` index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  /// Returns the dim index as a constant when it can be determined
  /// statically, either from the op itself or from a defining LLVM constant.
  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return std::nullopt;
  }

  /// Emits the size of dimension `dimOp.index()` for a ranked memref source:
  /// a constant for an in-bounds static dimension, a descriptor read
  /// otherwise.
  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage if index is constant.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // extract dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use constant for static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexConstant(rewriter, loc, dimSize);
      }
    }
    // Fully dynamic index: read the size from the descriptor at runtime.
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};
528 
/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  // NOTE(review): one or two declarations (likely inherited-constructor
  // `using` lines) appear to be missing from this copy of the file here.
  using Base = LoadStoreOpLowering<Derived>;

  // Only memrefs that convert cleanly and use identity layout maps are
  // handled by this family of patterns.
  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};
543 
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  // NOTE(review): the `LogicalResult` return type appears to be missing from
  // this copy of the file (extraction artifact).
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    // The loop block carries the current loaded value as a block argument.
    loopBlock->addArgument(valueType, loc);

    // NOTE(review): the postfix `++` below is applied to a temporary iterator
    // and therefore does not advance the split point — confirm whether
    // std::next(...) was intended.
    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};
634 
635 /// Returns the LLVM type of the global variable given the memref type `type`.
636 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
637  LLVMTypeConverter &typeConverter) {
638  // LLVM type for a global memref will be a multi-dimension array. For
639  // declarations or uninitialized global memrefs, we can potentially flatten
640  // this to a 1D array. However, for memref.global's with an initial value,
641  // we do not intend to flatten the ElementsAttribute when going from std ->
642  // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
643  Type elementType = typeConverter.convertType(type.getElementType());
644  Type arrayTy = elementType;
645  // Shape has the outermost dim at index 0, so need to walk it backwards
646  for (int64_t dim : llvm::reverse(type.getShape()))
647  arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
648  return arrayTy;
649 }
650 
/// GlobalMemrefOp is lowered to a LLVM Global Variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  // NOTE(review): the inherited-constructor `using` declaration and the
  // `LogicalResult` return type below appear to be missing from this copy of
  // the file (extraction artifact).

  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    // Public memref.globals become externally visible LLVM globals; all
    // others are private to the module.
    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element type,
      // so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      // Defined-but-uninitialized globals get an initializer region that
      // returns an undef of the array type.
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};
699 
700 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
701 /// the first element stashed into the descriptor. This reuses
702 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
703 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
704  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
705  : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
706  converter) {}
707 
708  /// Buffer "allocation" for memref.get_global op is getting the address of
709  /// the global variable referenced.
710  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
711  Location loc, Value sizeBytes,
712  Operation *op) const override {
713  auto getGlobalOp = cast<memref::GetGlobalOp>(op);
714  MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
715 
716  // This is called after a type conversion, which would have failed if this
717  // call fails.
718  unsigned memSpace = *getTypeConverter()->getMemRefAddressSpace(type);
719 
720  Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
721  Type resTy = getTypeConverter()->getPointerType(arrayTy, memSpace);
722  auto addressOf =
723  rewriter.create<LLVM::AddressOfOp>(loc, resTy, getGlobalOp.getName());
724 
725  // Get the address of the first element in the array by creating a GEP with
726  // the address of the GV as the base, and (rank + 1) number of 0 indices.
727  Type elementType = typeConverter->convertType(type.getElementType());
728  Type elementPtrType =
729  getTypeConverter()->getPointerType(elementType, memSpace);
730 
731  auto gep = rewriter.create<LLVM::GEPOp>(
732  loc, elementPtrType, arrayTy, addressOf,
733  SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
734 
735  // We do not expect the memref obtained using `memref.get_global` to be
736  // ever deallocated. Set the allocated pointer to be known bad value to
737  // help debug if that ever happens.
738  auto intPtrType = getIntPtrType(memSpace);
739  Value deadBeefConst =
740  createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
741  auto deadBeefPtr =
742  rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
743 
744  // Both allocated and aligned pointers are same. We could potentially stash
745  // a nullptr for the allocated pointer since we do not expect any dealloc.
746  return std::make_tuple(deadBeefPtr, gep);
747  }
748 };
749 
// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  // NOTE(review): the `LogicalResult` return type appears to be missing from
  // this copy of the file (extraction artifact).
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    // Address the element through the descriptor and the indices, then load
    // it, forwarding the op's nontemporal hint.
    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};
769 
// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  // NOTE(review): the `LogicalResult` return type appears to be missing from
  // this copy of the file (extraction artifact).
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    // Address the element through the descriptor and the indices, then store
    // the converted value, forwarding the op's nontemporal hint.
    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};
787 
// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  // NOTE(review): the `LogicalResult` return type appears to be missing from
  // this copy of the file (extraction artifact).
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch. The op's isWrite / localityHint /
    // isDataCache attributes are materialized as i32 constants.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type,
                                                     prefetchOp.getIsWrite());
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, prefetchOp.getLocalityHint());
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, prefetchOp.getIsDataCache());

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};
816 
817 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
819 
821  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
822  ConversionPatternRewriter &rewriter) const override {
823  Location loc = op.getLoc();
824  Type operandType = op.getMemref().getType();
825  if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
826  UnrankedMemRefDescriptor desc(adaptor.getMemref());
827  rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
828  return success();
829  }
830  if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
831  rewriter.replaceOp(
832  op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
833  return success();
834  }
835  return failure();
836  }
837 };
838 
839 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
841 
842  LogicalResult match(memref::CastOp memRefCastOp) const override {
843  Type srcType = memRefCastOp.getOperand().getType();
844  Type dstType = memRefCastOp.getType();
845 
846  // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
847  // used for type erasure. For now they must preserve underlying element type
848  // and require source and result type to have the same rank. Therefore,
849  // perform a sanity check that the underlying structs are the same. Once op
850  // semantics are relaxed we can revisit.
851  if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
852  return success(typeConverter->convertType(srcType) ==
853  typeConverter->convertType(dstType));
854 
855  // At least one of the operands is unranked type
856  assert(srcType.isa<UnrankedMemRefType>() ||
857  dstType.isa<UnrankedMemRefType>());
858 
859  // Unranked to unranked cast is disallowed
860  return !(srcType.isa<UnrankedMemRefType>() &&
861  dstType.isa<UnrankedMemRefType>())
862  ? success()
863  : failure();
864  }
865 
866  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
867  ConversionPatternRewriter &rewriter) const override {
868  auto srcType = memRefCastOp.getOperand().getType();
869  auto dstType = memRefCastOp.getType();
870  auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
871  auto loc = memRefCastOp.getLoc();
872 
873  // For ranked/ranked case, just keep the original descriptor.
874  if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
875  return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
876 
877  if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
878  // Casting ranked to unranked memref type
879  // Set the rank in the destination from the memref type
880  // Allocate space on the stack and copy the src memref descriptor
881  // Set the ptr in the destination to the stack space
882  auto srcMemRefType = srcType.cast<MemRefType>();
883  int64_t rank = srcMemRefType.getRank();
884  // ptr = AllocaOp sizeof(MemRefDescriptor)
885  auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
886  loc, adaptor.getSource(), rewriter);
887 
888  // voidptr = BitCastOp srcType* to void*
889  Value voidPtr;
890  if (getTypeConverter()->useOpaquePointers())
891  voidPtr = ptr;
892  else
893  voidPtr = rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);
894 
895  // rank = ConstantOp srcRank
896  auto rankVal = rewriter.create<LLVM::ConstantOp>(
897  loc, getIndexType(), rewriter.getIndexAttr(rank));
898  // undef = UndefOp
899  UnrankedMemRefDescriptor memRefDesc =
900  UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
901  // d1 = InsertValueOp undef, rank, 0
902  memRefDesc.setRank(rewriter, loc, rankVal);
903  // d2 = InsertValueOp d1, voidptr, 1
904  memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
905  rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
906 
907  } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
908  // Casting from unranked type to ranked.
909  // The operation is assumed to be doing a correct cast. If the destination
910  // type mismatches the unranked the type, it is undefined behavior.
911  UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
912  // ptr = ExtractValueOp src, 1
913  auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
914  // castPtr = BitCastOp i8* to structTy*
915  Value castPtr;
916  if (getTypeConverter()->useOpaquePointers())
917  castPtr = ptr;
918  else
919  castPtr = rewriter.create<LLVM::BitcastOp>(
920  loc, LLVM::LLVMPointerType::get(targetStructType), ptr);
921 
922  // struct = LoadOp castPtr
923  auto loadOp =
924  rewriter.create<LLVM::LoadOp>(loc, targetStructType, castPtr);
925  rewriter.replaceOp(memRefCastOp, loadOp.getResult());
926  } else {
927  llvm_unreachable("Unsupported unranked memref to unranked memref cast");
928  }
929  }
930 };
931 
932 /// Pattern to lower a `memref.copy` to llvm.
933 ///
934 /// For memrefs with identity layouts, the copy is lowered to the llvm
935 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
936 /// to the generic `MemrefCopyFn`.
937 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
939 
941  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
942  ConversionPatternRewriter &rewriter) const {
943  auto loc = op.getLoc();
944  auto srcType = op.getSource().getType().dyn_cast<MemRefType>();
945 
946  MemRefDescriptor srcDesc(adaptor.getSource());
947 
948  // Compute number of elements.
949  Value numElements = rewriter.create<LLVM::ConstantOp>(
950  loc, getIndexType(), rewriter.getIndexAttr(1));
951  for (int pos = 0; pos < srcType.getRank(); ++pos) {
952  auto size = srcDesc.size(rewriter, loc, pos);
953  numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
954  }
955 
956  // Get element size.
957  auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
958  // Compute total.
959  Value totalSize =
960  rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
961 
962  Type elementType = typeConverter->convertType(srcType.getElementType());
963 
964  Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
965  Value srcOffset = srcDesc.offset(rewriter, loc);
966  Value srcPtr = rewriter.create<LLVM::GEPOp>(
967  loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
968  MemRefDescriptor targetDesc(adaptor.getTarget());
969  Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
970  Value targetOffset = targetDesc.offset(rewriter, loc);
971  Value targetPtr = rewriter.create<LLVM::GEPOp>(
972  loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
973  Value isVolatile =
974  rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(false));
975  rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
976  isVolatile);
977  rewriter.eraseOp(op);
978 
979  return success();
980  }
981 
983  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
984  ConversionPatternRewriter &rewriter) const {
985  auto loc = op.getLoc();
986  auto srcType = op.getSource().getType().cast<BaseMemRefType>();
987  auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
988 
989  // First make sure we have an unranked memref descriptor representation.
990  auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
991  auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
992  type.getRank());
993  auto *typeConverter = getTypeConverter();
994  auto ptr =
995  typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
996 
997  Value voidPtr;
998  if (getTypeConverter()->useOpaquePointers())
999  voidPtr = ptr;
1000  else
1001  voidPtr = rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);
1002 
1003  auto unrankedType =
1004  UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
1005  return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
1006  unrankedType,
1007  ValueRange{rank, voidPtr});
1008  };
1009 
1010  // Save stack position before promoting descriptors
1011  auto stackSaveOp =
1012  rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
1013 
1014  Value unrankedSource = srcType.hasRank()
1015  ? makeUnranked(adaptor.getSource(), srcType)
1016  : adaptor.getSource();
1017  Value unrankedTarget = targetType.hasRank()
1018  ? makeUnranked(adaptor.getTarget(), targetType)
1019  : adaptor.getTarget();
1020 
1021  // Now promote the unranked descriptors to the stack.
1022  auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1023  rewriter.getIndexAttr(1));
1024  auto promote = [&](Value desc) {
1025  Type ptrType = getTypeConverter()->getPointerType(desc.getType());
1026  auto allocated =
1027  rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
1028  rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
1029  return allocated;
1030  };
1031 
1032  auto sourcePtr = promote(unrankedSource);
1033  auto targetPtr = promote(unrankedTarget);
1034 
1035  unsigned typeSize =
1036  mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
1037  auto elemSize = rewriter.create<LLVM::ConstantOp>(
1038  loc, getIndexType(), rewriter.getIndexAttr(typeSize));
1039  auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
1040  op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
1041  rewriter.create<LLVM::CallOp>(loc, copyFn,
1042  ValueRange{elemSize, sourcePtr, targetPtr});
1043 
1044  // Restore stack used for descriptors
1045  rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
1046 
1047  rewriter.eraseOp(op);
1048 
1049  return success();
1050  }
1051 
1053  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
1054  ConversionPatternRewriter &rewriter) const override {
1055  auto srcType = op.getSource().getType().cast<BaseMemRefType>();
1056  auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
1057 
1058  auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) {
1059  if (!type.hasStaticShape())
1060  return false;
1061 
1062  SmallVector<int64_t> strides;
1063  int64_t offset;
1064  if (failed(getStridesAndOffset(type, strides, offset)))
1065  return false;
1066 
1067  int64_t runningStride = 1;
1068  for (unsigned i = strides.size(); i > 0; --i) {
1069  if (strides[i - 1] != runningStride)
1070  return false;
1071  runningStride *= type.getDimSize(i - 1);
1072  }
1073  return true;
1074  };
1075 
1076  auto isContiguousMemrefType = [&](BaseMemRefType type) {
1077  auto memrefType = type.dyn_cast<mlir::MemRefType>();
1078  // We can use memcpy for memrefs if they have an identity layout or are
1079  // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
1080  // special case handled by memrefCopy.
1081  return memrefType &&
1082  (memrefType.getLayout().isIdentity() ||
1083  (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
1084  isStaticShapeAndContiguousRowMajor(memrefType)));
1085  };
1086 
1087  if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
1088  return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
1089 
1090  return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
1091  }
1092 };
1093 
1094 struct MemorySpaceCastOpLowering
1095  : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
1096  using ConvertOpToLLVMPattern<
1097  memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;
1098 
1100  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
1101  ConversionPatternRewriter &rewriter) const override {
1102  Location loc = op.getLoc();
1103 
1104  Type resultType = op.getDest().getType();
1105  if (auto resultTypeR = resultType.dyn_cast<MemRefType>()) {
1106  auto resultDescType =
1107  typeConverter->convertType(resultTypeR).cast<LLVM::LLVMStructType>();
1108  Type newPtrType = resultDescType.getBody()[0];
1109 
1110  SmallVector<Value> descVals;
1111  MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
1112  descVals);
1113  descVals[0] =
1114  rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
1115  descVals[1] =
1116  rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
1117  Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
1118  resultTypeR, descVals);
1119  rewriter.replaceOp(op, result);
1120  return success();
1121  }
1122  if (auto resultTypeU = resultType.dyn_cast<UnrankedMemRefType>()) {
1123  // Since the type converter won't be doing this for us, get the address
1124  // space.
1125  auto sourceType = op.getSource().getType().cast<UnrankedMemRefType>();
1126  FailureOr<unsigned> maybeSourceAddrSpace =
1127  getTypeConverter()->getMemRefAddressSpace(sourceType);
1128  if (failed(maybeSourceAddrSpace))
1129  return rewriter.notifyMatchFailure(loc,
1130  "non-integer source address space");
1131  unsigned sourceAddrSpace = *maybeSourceAddrSpace;
1132  FailureOr<unsigned> maybeResultAddrSpace =
1133  getTypeConverter()->getMemRefAddressSpace(resultTypeU);
1134  if (failed(maybeResultAddrSpace))
1135  return rewriter.notifyMatchFailure(loc,
1136  "non-integer result address space");
1137  unsigned resultAddrSpace = *maybeResultAddrSpace;
1138 
1139  UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
1140  Value rank = sourceDesc.rank(rewriter, loc);
1141  Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
1142 
1143  // Create and allocate storage for new memref descriptor.
1144  auto result = UnrankedMemRefDescriptor::undef(
1145  rewriter, loc, typeConverter->convertType(resultTypeU));
1146  result.setRank(rewriter, loc, rank);
1147  SmallVector<Value, 1> sizes;
1148  UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1149  result, resultAddrSpace, sizes);
1150  Value resultUnderlyingSize = sizes.front();
1151  Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
1152  loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
1153  result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
1154 
1155  // Copy pointers, performing address space casts.
1156  Type llvmElementType =
1157  typeConverter->convertType(sourceType.getElementType());
1158  LLVM::LLVMPointerType sourceElemPtrType =
1159  getTypeConverter()->getPointerType(llvmElementType, sourceAddrSpace);
1160  auto resultElemPtrType =
1161  getTypeConverter()->getPointerType(llvmElementType, resultAddrSpace);
1162 
1163  Value allocatedPtr = sourceDesc.allocatedPtr(
1164  rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
1165  Value alignedPtr =
1166  sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
1167  sourceUnderlyingDesc, sourceElemPtrType);
1168  allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1169  loc, resultElemPtrType, allocatedPtr);
1170  alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1171  loc, resultElemPtrType, alignedPtr);
1172 
1173  result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
1174  resultElemPtrType, allocatedPtr);
1175  result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
1176  resultUnderlyingDesc, resultElemPtrType, alignedPtr);
1177 
1178  // Copy all the index-valued operands.
1179  Value sourceIndexVals =
1180  sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1181  sourceUnderlyingDesc, sourceElemPtrType);
1182  Value resultIndexVals =
1183  result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1184  resultUnderlyingDesc, resultElemPtrType);
1185 
1186  int64_t bytesToSkip =
1187  2 *
1188  ceilDiv(getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
1189  Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
1190  loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
1191  Value copySize = rewriter.create<LLVM::SubOp>(
1192  loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
1193  Type llvmBool = typeConverter->convertType(rewriter.getI1Type());
1194  Value nonVolatile = rewriter.create<LLVM::ConstantOp>(
1195  loc, llvmBool, rewriter.getBoolAttr(false));
1196  rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
1197  copySize, nonVolatile);
1198 
1199  rewriter.replaceOp(op, ValueRange{result});
1200  return success();
1201  }
1202  return rewriter.notifyMatchFailure(loc, "unexpected memref type");
1203  }
1204 };
1205 
1206 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
1207 /// memref type. In unranked case, the fields are extracted from the underlying
1208 /// ranked descriptor.
1209 static void extractPointersAndOffset(Location loc,
1210  ConversionPatternRewriter &rewriter,
1211  LLVMTypeConverter &typeConverter,
1212  Value originalOperand,
1213  Value convertedOperand,
1214  Value *allocatedPtr, Value *alignedPtr,
1215  Value *offset = nullptr) {
1216  Type operandType = originalOperand.getType();
1217  if (operandType.isa<MemRefType>()) {
1218  MemRefDescriptor desc(convertedOperand);
1219  *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1220  *alignedPtr = desc.alignedPtr(rewriter, loc);
1221  if (offset != nullptr)
1222  *offset = desc.offset(rewriter, loc);
1223  return;
1224  }
1225 
1226  // These will all cause assert()s on unconvertible types.
1227  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
1228  operandType.cast<UnrankedMemRefType>());
1229  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
1230  Type llvmElementType = typeConverter.convertType(elementType);
1231  LLVM::LLVMPointerType elementPtrType =
1232  typeConverter.getPointerType(llvmElementType, memorySpace);
1233 
1234  // Extract pointer to the underlying ranked memref descriptor and cast it to
1235  // ElemType**.
1236  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1237  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1238 
1240  rewriter, loc, underlyingDescPtr, elementPtrType);
1242  rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1243  if (offset != nullptr) {
1245  rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1246  }
1247 }
1248 
1249 struct MemRefReinterpretCastOpLowering
1250  : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1251  using ConvertOpToLLVMPattern<
1252  memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1253 
1255  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1256  ConversionPatternRewriter &rewriter) const override {
1257  Type srcType = castOp.getSource().getType();
1258 
1259  Value descriptor;
1260  if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1261  adaptor, &descriptor)))
1262  return failure();
1263  rewriter.replaceOp(castOp, {descriptor});
1264  return success();
1265  }
1266 
1267 private:
1268  LogicalResult convertSourceMemRefToDescriptor(
1269  ConversionPatternRewriter &rewriter, Type srcType,
1270  memref::ReinterpretCastOp castOp,
1271  memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1272  MemRefType targetMemRefType =
1273  castOp.getResult().getType().cast<MemRefType>();
1274  auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1275  .dyn_cast_or_null<LLVM::LLVMStructType>();
1276  if (!llvmTargetDescriptorTy)
1277  return failure();
1278 
1279  // Create descriptor.
1280  Location loc = castOp.getLoc();
1281  auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1282 
1283  // Set allocated and aligned pointers.
1284  Value allocatedPtr, alignedPtr;
1285  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1286  castOp.getSource(), adaptor.getSource(),
1287  &allocatedPtr, &alignedPtr);
1288  desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1289  desc.setAlignedPtr(rewriter, loc, alignedPtr);
1290 
1291  // Set offset.
1292  if (castOp.isDynamicOffset(0))
1293  desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1294  else
1295  desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1296 
1297  // Set sizes and strides.
1298  unsigned dynSizeId = 0;
1299  unsigned dynStrideId = 0;
1300  for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1301  if (castOp.isDynamicSize(i))
1302  desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1303  else
1304  desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1305 
1306  if (castOp.isDynamicStride(i))
1307  desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1308  else
1309  desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1310  }
1311  *descriptor = desc;
1312  return success();
1313  }
1314 };
1315 
1316 struct MemRefReshapeOpLowering
1317  : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1319 
1321  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1322  ConversionPatternRewriter &rewriter) const override {
1323  Type srcType = reshapeOp.getSource().getType();
1324 
1325  Value descriptor;
1326  if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1327  adaptor, &descriptor)))
1328  return failure();
1329  rewriter.replaceOp(reshapeOp, {descriptor});
1330  return success();
1331  }
1332 
1333 private:
1335  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1336  Type srcType, memref::ReshapeOp reshapeOp,
1337  memref::ReshapeOp::Adaptor adaptor,
1338  Value *descriptor) const {
1339  auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
1340  if (shapeMemRefType.hasStaticShape()) {
1341  MemRefType targetMemRefType =
1342  reshapeOp.getResult().getType().cast<MemRefType>();
1343  auto llvmTargetDescriptorTy =
1344  typeConverter->convertType(targetMemRefType)
1345  .dyn_cast_or_null<LLVM::LLVMStructType>();
1346  if (!llvmTargetDescriptorTy)
1347  return failure();
1348 
1349  // Create descriptor.
1350  Location loc = reshapeOp.getLoc();
1351  auto desc =
1352  MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1353 
1354  // Set allocated and aligned pointers.
1355  Value allocatedPtr, alignedPtr;
1356  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1357  reshapeOp.getSource(), adaptor.getSource(),
1358  &allocatedPtr, &alignedPtr);
1359  desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1360  desc.setAlignedPtr(rewriter, loc, alignedPtr);
1361 
1362  // Extract the offset and strides from the type.
1363  int64_t offset;
1364  SmallVector<int64_t> strides;
1365  if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
1366  return rewriter.notifyMatchFailure(
1367  reshapeOp, "failed to get stride and offset exprs");
1368 
1369  if (!isStaticStrideOrOffset(offset))
1370  return rewriter.notifyMatchFailure(reshapeOp,
1371  "dynamic offset is unsupported");
1372 
1373  desc.setConstantOffset(rewriter, loc, offset);
1374 
1375  assert(targetMemRefType.getLayout().isIdentity() &&
1376  "Identity layout map is a precondition of a valid reshape op");
1377 
1378  Value stride = nullptr;
1379  int64_t targetRank = targetMemRefType.getRank();
1380  for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1381  if (!ShapedType::isDynamic(strides[i])) {
1382  // If the stride for this dimension is dynamic, then use the product
1383  // of the sizes of the inner dimensions.
1384  stride = createIndexConstant(rewriter, loc, strides[i]);
1385  } else if (!stride) {
1386  // `stride` is null only in the first iteration of the loop. However,
1387  // since the target memref has an identity layout, we can safely set
1388  // the innermost stride to 1.
1389  stride = createIndexConstant(rewriter, loc, 1);
1390  }
1391 
1392  Value dimSize;
1393  int64_t size = targetMemRefType.getDimSize(i);
1394  // If the size of this dimension is dynamic, then load it at runtime
1395  // from the shape operand.
1396  if (!ShapedType::isDynamic(size)) {
1397  dimSize = createIndexConstant(rewriter, loc, size);
1398  } else {
1399  Value shapeOp = reshapeOp.getShape();
1400  Value index = createIndexConstant(rewriter, loc, i);
1401  dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1402  Type indexType = getIndexType();
1403  if (dimSize.getType() != indexType)
1404  dimSize = typeConverter->materializeTargetConversion(
1405  rewriter, loc, indexType, dimSize);
1406  assert(dimSize && "Invalid memref element type");
1407  }
1408 
1409  desc.setSize(rewriter, loc, i, dimSize);
1410  desc.setStride(rewriter, loc, i, stride);
1411 
1412  // Prepare the stride value for the next dimension.
1413  stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
1414  }
1415 
1416  *descriptor = desc;
1417  return success();
1418  }
1419 
1420  // The shape is a rank-1 tensor with unknown length.
1421  Location loc = reshapeOp.getLoc();
1422  MemRefDescriptor shapeDesc(adaptor.getShape());
1423  Value resultRank = shapeDesc.size(rewriter, loc, 0);
1424 
1425  // Extract address space and element type.
1426  auto targetType =
1427  reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
1428  unsigned addressSpace =
1429  *getTypeConverter()->getMemRefAddressSpace(targetType);
1430  Type elementType = targetType.getElementType();
1431 
1432  // Create the unranked memref descriptor that holds the ranked one. The
1433  // inner descriptor is allocated on stack.
1434  auto targetDesc = UnrankedMemRefDescriptor::undef(
1435  rewriter, loc, typeConverter->convertType(targetType));
1436  targetDesc.setRank(rewriter, loc, resultRank);
1437  SmallVector<Value, 4> sizes;
1438  UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1439  targetDesc, addressSpace, sizes);
1440  Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1441  loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
1442  sizes.front());
1443  targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1444 
1445  // Extract pointers and offset from the source memref.
1446  Value allocatedPtr, alignedPtr, offset;
1447  extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1448  reshapeOp.getSource(), adaptor.getSource(),
1449  &allocatedPtr, &alignedPtr, &offset);
1450 
1451  // Set pointers and offset.
1452  Type llvmElementType = typeConverter->convertType(elementType);
1453  LLVM::LLVMPointerType elementPtrType =
1454  getTypeConverter()->getPointerType(llvmElementType, addressSpace);
1455 
1456  UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1457  elementPtrType, allocatedPtr);
1458  UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1459  underlyingDescPtr, elementPtrType,
1460  alignedPtr);
1461  UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1462  underlyingDescPtr, elementPtrType,
1463  offset);
1464 
1465  // Use the offset pointer as base for further addressing. Copy over the new
1466  // shape and compute strides. For this, we create a loop from rank-1 to 0.
1467  Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1468  rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1469  Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1470  rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1471  Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1472  Value oneIndex = createIndexConstant(rewriter, loc, 1);
1473  Value resultRankMinusOne =
1474  rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1475 
1476  Block *initBlock = rewriter.getInsertionBlock();
1477  Type indexType = getTypeConverter()->getIndexType();
1478  Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1479 
1480  Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1481  {indexType, indexType}, {loc, loc});
1482 
1483  // Move the remaining initBlock ops to condBlock.
1484  Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1485  rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1486 
1487  rewriter.setInsertionPointToEnd(initBlock);
1488  rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1489  condBlock);
1490  rewriter.setInsertionPointToStart(condBlock);
1491  Value indexArg = condBlock->getArgument(0);
1492  Value strideArg = condBlock->getArgument(1);
1493 
1494  Value zeroIndex = createIndexConstant(rewriter, loc, 0);
1495  Value pred = rewriter.create<LLVM::ICmpOp>(
1496  loc, IntegerType::get(rewriter.getContext(), 1),
1497  LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1498 
1499  Block *bodyBlock =
1500  rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1501  rewriter.setInsertionPointToStart(bodyBlock);
1502 
1503  // Copy size from shape to descriptor.
1504  Type llvmIndexPtrType = getTypeConverter()->getPointerType(indexType);
1505  Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1506  loc, llvmIndexPtrType,
1507  typeConverter->convertType(shapeMemRefType.getElementType()),
1508  shapeOperandPtr, indexArg);
1509  Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
1510  UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1511  targetSizesBase, indexArg, size);
1512 
1513  // Write stride value and compute next one.
1514  UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1515  targetStridesBase, indexArg, strideArg);
1516  Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1517 
1518  // Decrement loop counter and branch back.
1519  Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1520  rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1521  condBlock);
1522 
1523  Block *remainder =
1524  rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1525 
1526  // Hook up the cond exit to the remainder.
1527  rewriter.setInsertionPointToEnd(condBlock);
1528  rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
1529  remainder, std::nullopt);
1530 
1531  // Reset position to beginning of new remainder block.
1532  rewriter.setInsertionPointToStart(remainder);
1533 
1534  *descriptor = targetDesc;
1535  return success();
1536  }
1537 };
1538 
1539 /// RessociatingReshapeOp must be expanded before we reach this stage.
1540 /// Report that information.
1541 template <typename ReshapeOp>
1542 class ReassociatingReshapeOpConversion
1543  : public ConvertOpToLLVMPattern<ReshapeOp> {
1544 public:
1546  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1547 
1549  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1550  ConversionPatternRewriter &rewriter) const override {
1551  return rewriter.notifyMatchFailure(
1552  reshapeOp,
1553  "reassociation operations should have been expanded beforehand");
1554  }
1555 };
1556 
1557 /// Subviews must be expanded before we reach this stage.
1558 /// Report that information.
1559 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1561 
1563  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1564  ConversionPatternRewriter &rewriter) const override {
1565  return rewriter.notifyMatchFailure(
1566  subViewOp, "subview operations should have been expanded beforehand");
1567  }
1568 };
1569 
1570 /// Conversion pattern that transforms a transpose op into:
1571 /// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
1572 /// 2. A load of the ViewDescriptor from the pointer allocated in 1.
1573 /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
1574 /// and stride. Size and stride are permutations of the original values.
1575 /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
1576 /// The transpose op is replaced by the alloca'ed pointer.
1577 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1578 public:
1580 
1582  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1583  ConversionPatternRewriter &rewriter) const override {
1584  auto loc = transposeOp.getLoc();
1585  MemRefDescriptor viewMemRef(adaptor.getIn());
1586 
1587  // No permutation, early exit.
1588  if (transposeOp.getPermutation().isIdentity())
1589  return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1590 
1591  auto targetMemRef = MemRefDescriptor::undef(
1592  rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1593 
1594  // Copy the base and aligned pointers from the old descriptor to the new
1595  // one.
1596  targetMemRef.setAllocatedPtr(rewriter, loc,
1597  viewMemRef.allocatedPtr(rewriter, loc));
1598  targetMemRef.setAlignedPtr(rewriter, loc,
1599  viewMemRef.alignedPtr(rewriter, loc));
1600 
1601  // Copy the offset pointer from the old descriptor to the new one.
1602  targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1603 
1604  // Iterate over the dimensions and apply size/stride permutation.
1605  for (const auto &en :
1606  llvm::enumerate(transposeOp.getPermutation().getResults())) {
1607  int sourcePos = en.index();
1608  int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1609  targetMemRef.setSize(rewriter, loc, targetPos,
1610  viewMemRef.size(rewriter, loc, sourcePos));
1611  targetMemRef.setStride(rewriter, loc, targetPos,
1612  viewMemRef.stride(rewriter, loc, sourcePos));
1613  }
1614 
1615  rewriter.replaceOp(transposeOp, {targetMemRef});
1616  return success();
1617  }
1618 };
1619 
1620 /// Conversion pattern that transforms an op into:
1621 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1622 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1623 /// and stride.
1624 /// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic size.
  // `dynamicSizes` holds one operand per dynamic dimension, in dimension order,
  // so the operand for dimension `idx` is found by counting the dynamic
  // dimensions preceding it.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx]
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`. The caller should keep a running stride and update it with the
  // result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamic(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    // Dynamic stride: stride[idx] = stride[idx+1] * size[idx+1], seeded with
    // the innermost size when no running stride exists yet.
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    // Innermost dimension: unit stride (contiguity was checked by the caller).
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    // Both the element type and the full descriptor type of the result must
    // convert to LLVM-compatible types before any IR is emitted.
    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
    unsigned sourceMemorySpace =
        *getTypeConverter()->getMemRefAddressSpace(srcMemRefType);
    // With typed pointers the allocated pointer must be recast to the result
    // element type; opaque pointers need no cast.
    Value bitcastPtr;
    if (getTypeConverter()->useOpaquePointers())
      bitcastPtr = allocatedPtr;
    else
      bitcastPtr = rewriter.create<LLVM::BitcastOp>(
          loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace),
          allocatedPtr);

    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    // NOTE(review): the GEP advances the aligned pointer by `byteShift` in
    // units of the *source* element type — presumably the source is a
    // byte-element memref so this is a byte offset; confirm against the
    // memref.view op's verifier.
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        loc, alignedPtr.getType(),
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());

    if (getTypeConverter()->useOpaquePointers()) {
      bitcastPtr = alignedPtr;
    } else {
      bitcastPtr = rewriter.create<LLVM::BitcastOp>(
          loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace),
          alignedPtr);
    }

    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    // Walk dimensions innermost-first so each stride can be derived from the
    // size of the dimension just visited (see getStride).
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};
1751 
1752 //===----------------------------------------------------------------------===//
1753 // AtomicRMWOpLowering
1754 //===----------------------------------------------------------------------===//
1755 
1756 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1757 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1758 static std::optional<LLVM::AtomicBinOp>
1759 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1760  switch (atomicOp.getKind()) {
1761  case arith::AtomicRMWKind::addf:
1762  return LLVM::AtomicBinOp::fadd;
1763  case arith::AtomicRMWKind::addi:
1764  return LLVM::AtomicBinOp::add;
1765  case arith::AtomicRMWKind::assign:
1766  return LLVM::AtomicBinOp::xchg;
1767  case arith::AtomicRMWKind::maxs:
1768  return LLVM::AtomicBinOp::max;
1769  case arith::AtomicRMWKind::maxu:
1770  return LLVM::AtomicBinOp::umax;
1771  case arith::AtomicRMWKind::mins:
1772  return LLVM::AtomicBinOp::min;
1773  case arith::AtomicRMWKind::minu:
1774  return LLVM::AtomicBinOp::umin;
1775  case arith::AtomicRMWKind::ori:
1776  return LLVM::AtomicBinOp::_or;
1777  case arith::AtomicRMWKind::andi:
1778  return LLVM::AtomicBinOp::_and;
1779  default:
1780  return std::nullopt;
1781  }
1782  llvm_unreachable("Invalid AtomicRMWKind");
1783 }
1784 
1785 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1786  using Base::Base;
1787 
1789  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1790  ConversionPatternRewriter &rewriter) const override {
1791  if (failed(match(atomicOp)))
1792  return failure();
1793  auto maybeKind = matchSimpleAtomicOp(atomicOp);
1794  if (!maybeKind)
1795  return failure();
1796  auto memRefType = atomicOp.getMemRefType();
1797  auto dataPtr =
1798  getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1799  adaptor.getIndices(), rewriter);
1800  rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1801  atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1802  LLVM::AtomicOrdering::acq_rel);
1803  return success();
1804  }
1805 };
1806 
1807 /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
1808 class ConvertExtractAlignedPointerAsIndex
1809  : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1810 public:
1811  using ConvertOpToLLVMPattern<
1812  memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1813 
1815  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1816  OpAdaptor adaptor,
1817  ConversionPatternRewriter &rewriter) const override {
1818  MemRefDescriptor desc(adaptor.getSource());
1819  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1820  extractOp, getTypeConverter()->getIndexType(),
1821  desc.alignedPtr(rewriter, extractOp->getLoc()));
1822  return success();
1823  }
1824 };
1825 
1826 /// Materialize the MemRef descriptor represented by the results of
1827 /// ExtractStridedMetadataOp.
1828 class ExtractStridedMetadataOpLowering
1829  : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1830 public:
1831  using ConvertOpToLLVMPattern<
1832  memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1833 
1835  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1836  OpAdaptor adaptor,
1837  ConversionPatternRewriter &rewriter) const override {
1838 
1839  if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1840  return failure();
1841 
1842  // Create the descriptor.
1843  MemRefDescriptor sourceMemRef(adaptor.getSource());
1844  Location loc = extractStridedMetadataOp.getLoc();
1845  Value source = extractStridedMetadataOp.getSource();
1846 
1847  auto sourceMemRefType = source.getType().cast<MemRefType>();
1848  int64_t rank = sourceMemRefType.getRank();
1849  SmallVector<Value> results;
1850  results.reserve(2 + rank * 2);
1851 
1852  // Base buffer.
1853  Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1854  Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1856  rewriter, loc, *getTypeConverter(),
1857  extractStridedMetadataOp.getBaseBuffer().getType().cast<MemRefType>(),
1858  baseBuffer, alignedBuffer);
1859  results.push_back((Value)dstMemRef);
1860 
1861  // Offset.
1862  results.push_back(sourceMemRef.offset(rewriter, loc));
1863 
1864  // Sizes.
1865  for (unsigned i = 0; i < rank; ++i)
1866  results.push_back(sourceMemRef.size(rewriter, loc, i));
1867  // Strides.
1868  for (unsigned i = 0; i < rank; ++i)
1869  results.push_back(sourceMemRef.stride(rewriter, loc, i));
1870 
1871  rewriter.replaceOp(extractStridedMetadataOp, results);
1872  return success();
1873  }
1874 };
1875 
1876 } // namespace
1877 
1879  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1880  // clang-format off
1881  patterns.add<
1882  AllocaOpLowering,
1883  AllocaScopeOpLowering,
1884  AtomicRMWOpLowering,
1885  AssumeAlignmentOpLowering,
1886  ConvertExtractAlignedPointerAsIndex,
1887  DimOpLowering,
1888  ExtractStridedMetadataOpLowering,
1889  GenericAtomicRMWOpLowering,
1890  GlobalMemrefOpLowering,
1891  GetGlobalMemrefOpLowering,
1892  LoadOpLowering,
1893  MemRefCastOpLowering,
1894  MemRefCopyOpLowering,
1895  MemorySpaceCastOpLowering,
1896  MemRefReinterpretCastOpLowering,
1897  MemRefReshapeOpLowering,
1898  PrefetchOpLowering,
1899  RankOpLowering,
1900  ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1901  ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1902  StoreOpLowering,
1903  SubViewOpLowering,
1905  ViewOpLowering>(converter);
1906  // clang-format on
1907  auto allocLowering = converter.getOptions().allocLowering;
1908  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1909  patterns.add<AlignedAllocOpLowering, AlignedReallocOpLowering,
1910  DeallocOpLowering>(converter);
1911  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1912  patterns.add<AllocOpLowering, ReallocOpLowering, DeallocOpLowering>(
1913  converter);
1914 }
1915 
1916 namespace {
/// Pass that finalizes the MemRef-to-LLVM conversion: builds lowering options
/// from the pass flags and the module's data layout, populates the conversion
/// patterns above, and applies a partial conversion.
struct FinalizeMemRefToLLVMConversionPass
    : public impl::FinalizeMemRefToLLVMConversionPassBase<
          FinalizeMemRefToLLVMConversionPass> {
  using FinalizeMemRefToLLVMConversionPassBase::
      FinalizeMemRefToLLVMConversionPassBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    // Seed the lowering options with the data layout in effect at (or above)
    // the root operation.
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);

    options.useGenericFunctions = useGenericFunctions;
    options.useOpaquePointers = useOpaquePointers;

    // Only override the index bitwidth when explicitly requested; otherwise
    // it is derived from the data layout.
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    // func.func is kept legal here — presumably converted by a separate
    // func-to-llvm pass; confirm against the pass pipeline.
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
1948 } // namespace
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
static spirv::ScalarType getIndexType(MLIRContext *ctx, const SPIRVConversionOptions &options)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:698
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U cast() const
Definition: Attributes.h:176
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:116
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this memref type.
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:129
BlockArgument getArgument(unsigned i)
Definition: Block.h:118
Operation & back()
Definition: Block.h:141
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:141
BlockArgListType getArguments()
Definition: Block.h:76
Operation & front()
Definition: Block.h:142
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:198
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:121
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:84
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:113
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:70
IntegerType getI8Type()
Definition: Builders.cpp:76
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:135
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
unsigned getTypeSize(Type t) const
Returns the size of the given type in the current scope.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
const LowerToLLVMOptions & getOptions() const
Definition: TypeConverter.h:84
LLVM::LLVMPointerType getPointerType(Type elementType, unsigned addressSpace=0)
Creates an LLVM pointer type with the given element type and address space.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type)
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
bool useOpaquePointers() const
Returns true if using opaque pointers was enabled in the lowering options.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:106
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
static Value pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:329
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:426
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:520
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:412
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:379
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:417
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:405
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:423
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
Value getOperand(unsigned idx)
Definition: Operation.h:329
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:207
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.h:421
result_range getResults()
Definition: Operation.h:394
BlockListType::iterator iterator
Definition: Region.h:52
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:482
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:321
U dyn_cast() const
Definition: Types.h:311
bool isa() const
Definition: Types.h:301
static Value offset(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAlignedPtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
static void setSize(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static Value pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static Value alignedPtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
static void setStride(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static void computeSizes(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, ArrayRef< UnrankedMemRefDescriptor > values, ArrayRef< unsigned > addressSpaces, SmallVectorImpl< Value > &sizes)
Builds IR computing the sizes in bytes (suitable for opaque allocation) and appends the corresponding...
static Value sizeBasePtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static void setOffset(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType, Type unrankedDescriptorType)
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp, bool opaquePointers)
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp, bool opaquePointers)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:824
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
void promote(PatternRewriter &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:555
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateFinalizeMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Lowering for AllocOp and AllocaOp.
Lowering for memory allocation ops.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26