//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;

    if (useGenericFn)
      return LLVM::lookupOrCreateGenericAllocFn(module, getIndexType());

    return LLVM::lookupOrCreateMallocFn(module, getIndexType());
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.getAlignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns to the size of the
      // biggest scalar type on the target hardware. For non-scalars, use the
      // natural alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
    auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
    Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
                                                          results.getResult());

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};
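
// For illustration only: a `memref.alloc` with an alignment attribute lowers
// along these lines (a sketch; the actual size computation is emitted by
// AllocLikeOpLLVMLowering and elided here, and the element type is assumed
// to be f32):
//
//   %padded  = llvm.add %sizeBytes, %alignment
//   %raw     = llvm.call @malloc(%padded)
//   %alloc   = llvm.bitcast %raw : !llvm.ptr<i8> to !llvm.ptr<f32>
//   %int     = llvm.ptrtoint %alloc : !llvm.ptr<f32> to i64
//   %rounded = ... %int rounded up to the next multiple of %alignment ...
//   %aligned = llvm.inttoptr %rounded : i64 to !llvm.ptr<f32>
//
// Both %alloc and %aligned are then recorded in the memref descriptor.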

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType,
                                   Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (ShapedType::isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }
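
  // For example, for memref<4x?xf32> the divisor is 4 * sizeof(f32) = 16
  // bytes (the dynamic dimension is conservatively skipped), so the total
  // size is known to be a multiple of 16, but not necessarily of 32.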

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a power
    // of two, we bump it up to the next power of two if it isn't one already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;

    if (useGenericFn)
      return LLVM::lookupOrCreateGenericAlignedAllocFn(module, getIndexType());

    return LLVM::lookupOrCreateAlignedAllocFn(module, getIndexType());
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires the size to be a multiple of the alignment; pad
    // the size up to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
    auto results = rewriter.create<LLVM::CallOp>(
        loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
    Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
                                                          results.getResult());

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer with an LLVM `alloca` op. Since `alloca`
  /// already honors the requested alignment, the returned allocated and
  /// aligned pointers are the same.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // Stack allocations: with alloca, one gets a pointer to the element type
    // right away.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes, allocaOp.getAlignment().value_or(0));

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save the stack, then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body, restoring the stack before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert the stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};
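
// For illustration only, the lowered CFG looks roughly like this; everything
// alloca'ed inside the scope is reclaimed by the stackrestore:
//
//   ^current:
//     %sp = llvm.intr.stacksave
//     llvm.br ^body
//   ^body:
//     ... inlined alloca_scope body ...
//     llvm.intr.stackrestore %sp
//     llvm.br ^continue(...)
//   ^continue(...):
//     ... code after the alloca_scope ...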

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the
    // same pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};
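
// For illustration only, `memref.assume_alignment %m, 16` emits roughly:
//
//   %p   = ... aligned pointer extracted from the descriptor of %m ...
//   %int = llvm.ptrtoint %p
//   %low = llvm.and %int, %mask        // mask = alignment - 1 = 15
//   %ok  = llvm.icmp "eq" %low, %zero
//   llvm.intr.assume %ok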

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LLVM::LLVMFuncOp getFreeFn(ModuleOp module) const {
    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;

    if (useGenericFn)
      return LLVM::lookupOrCreateGenericFreeFn(module);

    return LLVM::lookupOrCreateFreeFn(module);
  }

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    auto freeFunc = getFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(adaptor.getMemref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, casted);
    return success();
  }
};
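
// For illustration only, `memref.dealloc %m` becomes roughly:
//
//   %alloc  = llvm.extractvalue %desc[0]  // allocated (not aligned) pointer
//   %casted = llvm.bitcast %alloc : !llvm.ptr<f32> to !llvm.ptr<i8>
//   llvm.call @free(%casted)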

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfUnrankedMemRef(
                     operandType, dimOp, adaptor.getOperands(), rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr, ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
    Value sizePtr =
        rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr, idxPlusOne);
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }
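
  // Note: the underlying ranked descriptor is laid out in memory as
  //   { allocatedPtr, alignedPtr, offset, sizes[rank], strides[rank] },
  // so starting from the address of the `offset` field (struct index 2), the
  // i-th size lives at index i + 1, which is what the GEPs above compute.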

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.getValue()
          .cast<IntegerAttr>()
          .getValue()
          .getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the index being constant when possible.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(adaptor.getSource());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///      -------|   |
///      |      v   v
///      |   +--------------------------------+
///      |   |   loop(%loaded):               |
///      |   |   <body contents>              |
///      |   |   %pair = cmpxchg              |
///      |   |   %ok = %pair[0]               |
///      |   |   %new = %pair[1]              |
///      |   |   cf.cond_br %ok, end, loop(%new) |
///      |   +--------------------------------+
///      |          |        |
///      |-----------        |
///                          v
///      +--------------------------------+
///      |   end:                         |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType =
        typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.getAtomicBody().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) and erases the original.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};
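
// For illustration only, `memref.generic_atomic_rmw %m[%i]` whose body
// computes %new from %current lowers roughly to:
//
//     %init = llvm.load %ptr
//     llvm.br ^loop(%init)
//   ^loop(%loaded):
//     %new  = ... cloned body, %current replaced by %loaded ...
//     %pair = llvm.cmpxchg %ptr, %loaded, %new acq_rel monotonic
//     %val  = llvm.extractvalue %pair[0]
//     %ok   = llvm.extractvalue %pair[1]
//     llvm.cond_br %ok, ^end, ^loop(%val)
//   ^end:
//     ... %val is the result of the atomic_rmw ...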

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so we need to walk it
  // backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}
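
// For example, memref<4x8xf32> maps to !llvm.array<4 x array<8 x f32>>: the
// outermost memref dimension becomes the outermost array level.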

/// GlobalMemrefOp is lowered to an LLVM global variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);

    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, type.getMemorySpaceAsInt());
    if (!global.isExternal() && global.isUninitialized()) {
      Block *blk = new Block();
      newGlobal.getInitializerRegion().push_back(blk);
      rewriter.setInsertionPointToStart(blk);
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};
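
// For illustration only, a global such as
//   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
// becomes roughly
//   llvm.mlir.global private constant @cst(dense<[1.0, 2.0]> : tensor<2xf32>)
//       : !llvm.array<2 x f32>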

/// GetGlobalMemrefOp is lowered into a memref descriptor with the pointer to
/// the first element stashed into the descriptor. This derives from
/// `AllocLikeOpLLVMLowering` to reuse the memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for the memref.get_global op is getting the address
  /// of the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace),
        getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP
    // with the address of the GV as the base, and (rank + 1) number of 0
    // indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, elementPtrType, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};
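
// For illustration only, `memref.get_global @gv : memref<2x3xf32>` produces
// roughly:
//
//   %addr = llvm.mlir.addressof @gv : !llvm.ptr<array<2 x array<3 x f32>>>
//   %gep  = llvm.getelementptr %addr[0, 0, 0] : ... -> !llvm.ptr<f32>
//
// %gep then becomes the aligned pointer in the resulting descriptor, while
// the allocated pointer is set to the 0xdeadbeef marker.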

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};
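
// For illustration only, `memref.load %m[%i, %j] : memref<?x?xf32>` becomes
// roughly:
//
//   %base = llvm.extractvalue %desc[1]                  // aligned pointer
//   %idx  = ... offset + %i * stride0 + %j * stride1 ...
//   %addr = llvm.getelementptr %base[%idx]
//   %val  = llvm.load %addr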

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type,
                                                     prefetchOp.getIsWrite());
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, prefetchOp.getLocalityHint());
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, prefetchOp.getIsDataCache());

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can
    // be used for type erasure. For now it must preserve the underlying
    // element type and requires the source and result types to have the same
    // rank. Therefore, perform a sanity check that the underlying structs are
    // the same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting a ranked to an unranked memref type:
      // - Set the rank in the destination from the memref type.
      // - Allocate space on the stack and copy the src memref descriptor.
      // - Set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() &&
               dstType.isa<MemRefType>()) {
      // Casting from an unranked type to a ranked one. The operation is
      // assumed to be doing a correct cast. If the destination type
      // mismatches the unranked type, it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};
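
// For illustration only, casting memref<2xf32> to memref<*xf32> builds the
// unranked descriptor roughly as:
//
//   %ptr  = llvm.alloca ...                        // stack copy of the
//   llvm.store %rankedDesc, %ptr                   // ranked descriptor
//   %void = llvm.bitcast %ptr : ... to !llvm.ptr<i8>
//   %u0   = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
//   %u1   = llvm.insertvalue %rank, %u0[0]
//   %u2   = llvm.insertvalue %void, %u1[1]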

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with an identity layout, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a
/// call to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().dyn_cast<MemRefType>();

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute the number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get the element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute the total size.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
                                                srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
                                                   targetBasePtr, targetOffset);
    Value isVolatile =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(false));
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    isVolatile);
    rewriter.eraseOp(op);

    return success();
  }
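
  // For example, for a memref<?x42xf32> source the code above computes
  //   totalSize = size0 * 42 * sizeof(f32)
  // and copies that many bytes from the source's aligned pointer (advanced
  // by its offset) to the target's.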

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.getSource(), srcType)
                               : adaptor.getSource();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.getTarget(), targetType)
                               : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    unsigned typeSize =
        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();

    auto isContiguousMemrefType = [](BaseMemRefType type) {
      auto memrefType = type.dyn_cast<mlir::MemRefType>();
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create the descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set the offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          reshapeOp.getResult().getType().cast<MemRefType>();
      auto llvmTargetDescriptorTy =
          typeConverter->convertType(targetMemRefType)
              .dyn_cast_or_null<LLVM::LLVMStructType>();
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create the descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          // If the stride for this dimension is static, use it directly.
          stride = createIndexConstant(rewriter, loc, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. Since
          // the target memref has an identity layout, we can safely set the
          // innermost stride to 1. In later iterations, a dynamic stride is
          // the running product of the inner dimension sizes, updated at the
          // bottom of the loop.
          stride = createIndexConstant(rewriter, loc, 1);
        }

        Value dimSize;
        int64_t size = targetMemRefType.getDimSize(i);
        // If the size of this dimension is dynamic, load it at runtime from
        // the shape operand.
        if (!ShapedType::isDynamic(size)) {
          dimSize = createIndexConstant(rewriter, loc, size);
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexConstant(rewriter, loc, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          Type indexType = getIndexType();
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the
    // new shape and compute strides. For this, we create a loop from rank-1
    // to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(loc, llvmIndexPtrType,
                                                     shapeOperandPtr, indexArg);
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};
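
// In the unranked case above, the emitted IR walks the dimensions from
// rank - 1 down to 0 in an explicit loop, copying each size from the shape
// operand into the descriptor and accumulating strides, roughly:
//
//   stride = 1
//   for (i = rank - 1; i >= 0; --i) {
//     sizes[i] = shape[i];
//     strides[i] = stride;
//     stride *= sizes[i];
//   }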

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it's the inverse
/// of the `reassociation` maps.
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}
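
// For example, for the reassociation [[0, 1], [2]] (expanding a 2-D memref
// into a 3-D one), the resulting map is {0 -> 0, 1 -> 0, 2 -> 1}.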

static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the multiplication of all the out dim sizes except the
  // current dim.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}
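
// For example, when expanding memref<?xf32> into memref<?x4xf32>, the dynamic
// output dimension is inDimSize / 4, where 4 is the product of the static
// sibling dimensions in the same reassociation group.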

static OpFoldResult getCollapsedOutputDimSize(
    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassociation) {
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
  Value outDimSizeDynamic = c1;
  for (auto inDimIndex : reassociation[outDimIndex]) {
    int64_t inDimSize = inStaticShape[inDimIndex];
    Value inDimSizeDynamic =
        ShapedType::isDynamic(inDimSize)
            ? inDesc.size(b, loc, inDimIndex)
            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                         b.getIndexAttr(inDimSize));
    outDimSizeDynamic =
        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
  }
  return outDimSizeDynamic;
}
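
// For example, when collapsing memref<?x4xf32> into memref<?xf32>, the
// dynamic output dimension is size0 * 4, the product of the input sizes in
// its reassociation group.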

static SmallVector<OpFoldResult>
getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                        ArrayRef<ReassociationIndices> reassociation,
                        ArrayRef<int64_t> inStaticShape,
                        MemRefDescriptor &inDesc,
                        ArrayRef<int64_t> outStaticShape) {
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                         outStaticShape[outDimIndex],
                                         inStaticShape, inDesc, reassociation);
      }));
}
1404 
1405 static SmallVector<OpFoldResult>
1406 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1407  ArrayRef<ReassociationIndices> reassociation,
1408  ArrayRef<int64_t> inStaticShape,
1409  MemRefDescriptor &inDesc,
1410  ArrayRef<int64_t> outStaticShape) {
1411  DenseMap<int64_t, int64_t> outDimToInDimMap =
1412  getExpandedDimToCollapsedDimMap(reassociation);
1413  return llvm::to_vector<4>(llvm::map_range(
1414  llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1415  return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1416  outStaticShape, inDesc, inStaticShape,
1417  reassociation, outDimToInDimMap);
1418  }));
1419 }
1420 
1421 static SmallVector<Value>
1422 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1423  ArrayRef<ReassociationIndices> reassociation,
1424  ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1425  ArrayRef<int64_t> outStaticShape) {
1426  return outStaticShape.size() < inStaticShape.size()
1427  ? getAsValues(b, loc, llvmIndexType,
1428  getCollapsedOutputShape(b, loc, llvmIndexType,
1429  reassociation, inStaticShape,
1430  inDesc, outStaticShape))
1431  : getAsValues(b, loc, llvmIndexType,
1432  getExpandedOutputShape(b, loc, llvmIndexType,
1433  reassociation, inStaticShape,
1434  inDesc, outStaticShape));
1435 }
1436 
1437 static void fillInStridesForExpandedMemDescriptor(
1438  OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
1439  MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1440  // See comments for computeExpandedLayoutMap for details on how the strides
1441  // are calculated.
1442  for (auto &en : llvm::enumerate(reassociation)) {
1443  auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
1444  for (auto dstIndex : llvm::reverse(en.value())) {
1445  dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
1446  Value size = dstDesc.size(b, loc, dstIndex);
1447  currentStrideToExpand =
1448  b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
1449  }
1450  }
1451 }
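// For example, expanding memref<8xf32> (stride 1) into memref<2x4xf32> with
// group [0, 1]: walking the group right-to-left gives stride(dim 1) = 1 and
// then stride(dim 0) = size(dim 1) * 1 = 4.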
1452 
1453 static void fillInStridesForCollapsedMemDescriptor(
1454  ConversionPatternRewriter &rewriter, Location loc, Operation *op,
1455  TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
1456  MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
1457  // See comments for computeCollapsedLayoutMap for details on how the strides
1458  // are calculated.
1459  auto srcShape = srcType.getShape();
1460  for (auto &en : llvm::enumerate(reassociation)) {
1461  rewriter.setInsertionPoint(op);
1462  auto dstIndex = en.index();
1463  ArrayRef<int64_t> ref = llvm::makeArrayRef(en.value());
1464  while (srcShape[ref.back()] == 1 && ref.size() > 1)
1465  ref = ref.drop_back();
1466  if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1467  dstDesc.setStride(rewriter, loc, dstIndex,
1468  srcDesc.stride(rewriter, loc, ref.back()));
1469  } else {
1470  // Iterate over the source strides in reverse order. Skip over the
1471  // dimensions whose size is 1.
1472  // TODO: we should take the minimum stride in the reassociation group
1473  // instead of just the first where the dimension is not 1.
1474  //
1475  // +------------------------------------------------------+
1476  // | curEntry: |
1477  // | %srcStride = strides[srcIndex] |
1478  // | %neOne = cmp sizes[srcIndex],1 +--+
1479  // | cf.cond_br %neOne, continue(%srcStride), nextEntry | |
1480  // +-------------------------+----------------------------+ |
1481  // | |
1482  // v |
1483  // +-----------------------------+ |
1484  // | nextEntry: | |
1485  // | ... +---+ |
1486  // +--------------+--------------+ | |
1487  // | | |
1488  // v | |
1489  // +-----------------------------+ | |
1490  // | nextEntry: | | |
1491  // | ... | | |
1492  // +--------------+--------------+ | +--------+
1493  // | | |
1494  // v v v
1495  // +--------------------------------------------------+
1496  // | continue(%newStride): |
1497  // | %newMemRefDes = setStride(%newStride,dstIndex) |
1498  // +--------------------------------------------------+
1499  OpBuilder::InsertionGuard guard(rewriter);
1500  Block *initBlock = rewriter.getInsertionBlock();
1501  Block *continueBlock =
1502  rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
1503  continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
1504  rewriter.setInsertionPointToStart(continueBlock);
1505  dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
1506 
1507  Block *curEntryBlock = initBlock;
1508  Block *nextEntryBlock;
1509  for (auto srcIndex : llvm::reverse(ref)) {
1510  if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
1511  continue;
1512  rewriter.setInsertionPointToEnd(curEntryBlock);
1513  Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
1514  if (srcIndex == ref.front()) {
1515  rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
1516  break;
1517  }
1518  Value one = rewriter.create<LLVM::ConstantOp>(
1519  loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1));
1520  Value predNeOne = rewriter.create<LLVM::ICmpOp>(
1521  loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
1522  one);
1523  {
1524  OpBuilder::InsertionGuard guard(rewriter);
1525  nextEntryBlock = rewriter.createBlock(
1526  initBlock->getParent(), Region::iterator(continueBlock), {});
1527  }
1528  rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
1529  srcStride, nextEntryBlock, llvm::None);
1530  curEntryBlock = nextEntryBlock;
1531  }
1532  }
1533  }
1534 }
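// For example, collapsing memref<?x?xf32> with strides [%s0, %s1] into
// memref<?xf32>: the CFG above selects %s1 when size(1) != 1 and otherwise
// falls through to %s0, since the stride of a unit dimension carries no
// information about the layout.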
1535 
1536 static void fillInDynamicStridesForMemDescriptor(
1537  ConversionPatternRewriter &b, Location loc, Operation *op,
1538  TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
1539  MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
1540  ArrayRef<ReassociationIndices> reassociation) {
1541  if (srcType.getRank() > dstType.getRank())
1542  fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
1543  srcDesc, dstDesc, reassociation);
1544  else
1545  fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
1546  reassociation);
1547 }
1548 
1549 // ReshapeOp creates a new view descriptor of the proper rank.
1550 // Sizes and strides may be static or dynamic; dynamic strides are recomputed
1551 // from the reassociation groups when they cannot be taken from the type.
1552 template <typename ReshapeOp>
1553 class ReassociatingReshapeOpConversion
1554  : public ConvertOpToLLVMPattern<ReshapeOp> {
1555 public:
1556  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1557  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1558 
1559  LogicalResult
1560  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1561  ConversionPatternRewriter &rewriter) const override {
1562  MemRefType dstType = reshapeOp.getResultType();
1563  MemRefType srcType = reshapeOp.getSrcType();
1564 
1565  int64_t offset;
1566  SmallVector<int64_t, 4> strides;
1567  if (failed(getStridesAndOffset(dstType, strides, offset))) {
1568  return rewriter.notifyMatchFailure(
1569  reshapeOp, "failed to get stride and offset exprs");
1570  }
1571 
1572  MemRefDescriptor srcDesc(adaptor.getSrc());
1573  Location loc = reshapeOp->getLoc();
1574  auto dstDesc = MemRefDescriptor::undef(
1575  rewriter, loc, this->typeConverter->convertType(dstType));
1576  dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1577  dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1578  dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1579 
1580  ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1581  ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1582  Type llvmIndexType =
1583  this->typeConverter->convertType(rewriter.getIndexType());
1584  SmallVector<Value> dstShape = getDynamicOutputShape(
1585  rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1586  srcStaticShape, srcDesc, dstStaticShape);
1587  for (auto &en : llvm::enumerate(dstShape))
1588  dstDesc.setSize(rewriter, loc, en.index(), en.value());
1589 
1590  if (llvm::all_of(strides, isStaticStrideOrOffset)) {
1591  for (auto &en : llvm::enumerate(strides))
1592  dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1593  } else if (srcType.getLayout().isIdentity() &&
1594  dstType.getLayout().isIdentity()) {
1595  Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1596  rewriter.getIndexAttr(1));
1597  Value stride = c1;
1598  for (auto dimIndex :
1599  llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1600  dstDesc.setStride(rewriter, loc, dimIndex, stride);
1601  stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1602  }
1603  } else {
1604  // There could be mixed static/dynamic strides. For simplicity, we
1605  // recompute all strides if there is at least one dynamic stride.
1606  fillInDynamicStridesForMemDescriptor(
1607  rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
1608  srcDesc, dstDesc, reshapeOp.getReassociationIndices());
1609  }
1610  rewriter.replaceOp(reshapeOp, {dstDesc});
1611  return success();
1612  }
1613 };
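// As an illustrative example (static shapes, identity layout):
//   %e = memref.expand_shape %m [[0, 1]] : memref<8xf32> into memref<2x4xf32>
// lowers to descriptor surgery only: the allocated/aligned pointers and the
// offset are copied from the descriptor of %m, the sizes become the constants
// 2 and 4, and the strides the constants 4 and 1 taken from the result type.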
1614 
1615 /// Conversion pattern that transforms a subview op into:
1616 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1617 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1618 /// and stride.
1619 /// The subview op is replaced by the descriptor.
1620 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1621  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1622 
1623  LogicalResult
1624  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1625  ConversionPatternRewriter &rewriter) const override {
1626  auto loc = subViewOp.getLoc();
1627 
1628  auto sourceMemRefType = subViewOp.getSource().getType().cast<MemRefType>();
1629  auto sourceElementTy =
1630  typeConverter->convertType(sourceMemRefType.getElementType());
1631 
1632  auto viewMemRefType = subViewOp.getType();
1633  auto inferredType =
1634  memref::SubViewOp::inferResultType(
1635  subViewOp.getSourceType(),
1636  extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
1637  extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
1638  extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
1639  .cast<MemRefType>();
1640  auto targetElementTy =
1641  typeConverter->convertType(viewMemRefType.getElementType());
1642  auto targetDescTy = typeConverter->convertType(viewMemRefType);
1643  if (!sourceElementTy || !targetDescTy || !targetElementTy ||
1644  !LLVM::isCompatibleType(sourceElementTy) ||
1645  !LLVM::isCompatibleType(targetElementTy) ||
1646  !LLVM::isCompatibleType(targetDescTy))
1647  return failure();
1648 
1649  // Extract the offset and strides from the type.
1650  int64_t offset;
1651  SmallVector<int64_t, 4> strides;
1652  auto successStrides = getStridesAndOffset(inferredType, strides, offset);
1653  if (failed(successStrides))
1654  return failure();
1655 
1656  // Create the descriptor.
1657  if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1658  return failure();
1659  MemRefDescriptor sourceMemRef(adaptor.getOperands().front());
1660  auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1661 
1662  // Copy the buffer pointer from the old descriptor to the new one.
1663  Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
1664  Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1665  loc,
1666  LLVM::LLVMPointerType::get(targetElementTy,
1667  viewMemRefType.getMemorySpaceAsInt()),
1668  extracted);
1669  targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1670 
1671  // Copy the aligned pointer from the old descriptor to the new one.
1672  extracted = sourceMemRef.alignedPtr(rewriter, loc);
1673  bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1674  loc,
1675  LLVM::LLVMPointerType::get(targetElementTy,
1676  viewMemRefType.getMemorySpaceAsInt()),
1677  extracted);
1678  targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1679 
1680  size_t inferredShapeRank = inferredType.getRank();
1681  size_t resultShapeRank = viewMemRefType.getRank();
1682 
1683  // Extract strides needed to compute offset.
1684  SmallVector<Value, 4> strideValues;
1685  strideValues.reserve(inferredShapeRank);
1686  for (unsigned i = 0; i < inferredShapeRank; ++i)
1687  strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
1688 
1689  // Offset.
1690  auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1691  if (!ShapedType::isDynamicStrideOrOffset(offset)) {
1692  targetMemRef.setConstantOffset(rewriter, loc, offset);
1693  } else {
1694  Value baseOffset = sourceMemRef.offset(rewriter, loc);
1695  // `inferredShapeRank` may be larger than the number of offset operands
1696  // because of trailing semantics. In this case, the offset is guaranteed
1697  // to be interpreted as 0 and we can just skip the extra dimensions.
1698  for (unsigned i = 0, e = std::min(inferredShapeRank,
1699  subViewOp.getMixedOffsets().size());
1700  i < e; ++i) {
1701  Value offset =
1702  // TODO: need OpFoldResult ODS adaptor to clean this up.
1703  subViewOp.isDynamicOffset(i)
1704  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicOffset(i)]
1705  : rewriter.create<LLVM::ConstantOp>(
1706  loc, llvmIndexType,
1707  rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
1708  Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
1709  baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
1710  }
1711  targetMemRef.setOffset(rewriter, loc, baseOffset);
1712  }
1713 
1714  // Update sizes and strides.
1715  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
1716  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
1717  assert(mixedSizes.size() == mixedStrides.size() &&
1718  "expected sizes and strides of equal length");
1719  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1720  for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
1721  i >= 0 && j >= 0; --i) {
1722  if (unusedDims.test(i))
1723  continue;
1724 
1725  // `i` may index past the end of subViewOp.getMixedSizes() because of
1726  // trailing semantics. In this case, the size is taken from the source
1727  // memref's shape (via memref.dim if dynamic) and the stride is 1.
1728  Value size, stride;
1729  if (static_cast<unsigned>(i) >= mixedSizes.size()) {
1730  // If the static size is available, use it directly. This is similar to
1731  // the folding of dim(constant-op) but removes the need for dim to be
1732  // aware of LLVM constants and for this pass to be aware of std
1733  // constants.
1734  int64_t staticSize =
1735  subViewOp.getSource().getType().cast<MemRefType>().getShape()[i];
1736  if (staticSize != ShapedType::kDynamicSize) {
1737  size = rewriter.create<LLVM::ConstantOp>(
1738  loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
1739  } else {
1740  Value pos = rewriter.create<LLVM::ConstantOp>(
1741  loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
1742  Value dim =
1743  rewriter.create<memref::DimOp>(loc, subViewOp.getSource(), pos);
1744  auto cast = rewriter.create<UnrealizedConversionCastOp>(
1745  loc, llvmIndexType, dim);
1746  size = cast.getResult(0);
1747  }
1748  stride = rewriter.create<LLVM::ConstantOp>(
1749  loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
1750  } else {
1751  // TODO: need OpFoldResult ODS adaptor to clean this up.
1752  size =
1753  subViewOp.isDynamicSize(i)
1754  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicSize(i)]
1755  : rewriter.create<LLVM::ConstantOp>(
1756  loc, llvmIndexType,
1757  rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
1758  if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
1759  stride = rewriter.create<LLVM::ConstantOp>(
1760  loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
1761  } else {
1762  stride =
1763  subViewOp.isDynamicStride(i)
1764  ? adaptor.getOperands()[subViewOp.getIndexOfDynamicStride(i)]
1765  : rewriter.create<LLVM::ConstantOp>(
1766  loc, llvmIndexType,
1767  rewriter.getI64IntegerAttr(
1768  subViewOp.getStaticStride(i)));
1769  stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
1770  }
1771  }
1772  targetMemRef.setSize(rewriter, loc, j, size);
1773  targetMemRef.setStride(rewriter, loc, j, stride);
1774  j--;
1775  }
1776 
1777  rewriter.replaceOp(subViewOp, {targetMemRef});
1778  return success();
1779  }
1780 };
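// As an illustrative example:
//   %v = memref.subview %m[%o, 1] [4, 4] [1, 1]
//       : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: ?>>
// produces a descriptor whose offset is offset(%m) + %o * 8 + 1 * 1, whose
// sizes are the constants 4 and 4, and whose strides are the constants 8, 1.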
1781 
1782 /// Conversion pattern that transforms a transpose op into:
1783 /// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
1784 /// 2. A load of the ViewDescriptor from the pointer allocated in 1.
1785 /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
1786 /// and stride. Size and stride are permutations of the original values.
1787 /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
1788 /// The transpose op is replaced by the alloca'ed pointer.
1789 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1790 public:
1791  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1792 
1793  LogicalResult
1794  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1795  ConversionPatternRewriter &rewriter) const override {
1796  auto loc = transposeOp.getLoc();
1797  MemRefDescriptor viewMemRef(adaptor.getIn());
1798 
1799  // No permutation, early exit.
1800  if (transposeOp.getPermutation().isIdentity())
1801  return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1802 
1803  auto targetMemRef = MemRefDescriptor::undef(
1804  rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1805 
1806  // Copy the base and aligned pointers from the old descriptor to the new
1807  // one.
1808  targetMemRef.setAllocatedPtr(rewriter, loc,
1809  viewMemRef.allocatedPtr(rewriter, loc));
1810  targetMemRef.setAlignedPtr(rewriter, loc,
1811  viewMemRef.alignedPtr(rewriter, loc));
1812 
1813  // Copy the offset pointer from the old descriptor to the new one.
1814  targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1815 
1816  // Iterate over the dimensions and apply size/stride permutation.
1817  for (const auto &en :
1818  llvm::enumerate(transposeOp.getPermutation().getResults())) {
1819  int sourcePos = en.index();
1820  int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1821  targetMemRef.setSize(rewriter, loc, targetPos,
1822  viewMemRef.size(rewriter, loc, sourcePos));
1823  targetMemRef.setStride(rewriter, loc, targetPos,
1824  viewMemRef.stride(rewriter, loc, sourcePos));
1825  }
1826 
1827  rewriter.replaceOp(transposeOp, {targetMemRef});
1828  return success();
1829  }
1830 };
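// For example, with the permutation (d0, d1) -> (d1, d0) the lowering simply
// swaps the two size and the two stride entries of the descriptor; no data is
// moved, so a transposed view costs only a few extract/insertvalue ops.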
1831 
1832 /// Conversion pattern that transforms an op into:
1833 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1834 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1835 /// and stride.
1836 /// The view op is replaced by the descriptor.
1837 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1838  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1839 
1840  // Build and return the value for the idx^th shape dimension, either as a
1841  // constant or by picking the matching operand from `dynamicSizes`.
1842  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1843  ArrayRef<int64_t> shape, ValueRange dynamicSizes,
1844  unsigned idx) const {
1845  assert(idx < shape.size());
1846  if (!ShapedType::isDynamic(shape[idx]))
1847  return createIndexConstant(rewriter, loc, shape[idx]);
1848  // Count the number of dynamic dims in range [0, idx]
1849  unsigned nDynamic =
1850  llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1851  return dynamicSizes[nDynamic];
1852  }
1853 
1854  // Build and return the idx^th stride, either by returning the constant stride
1855  // or by computing the dynamic stride from the current `runningStride` and
1856  // `nextSize`. The caller should keep a running stride and update it with the
1857  // result returned by this function.
1858  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1859  ArrayRef<int64_t> strides, Value nextSize,
1860  Value runningStride, unsigned idx) const {
1861  assert(idx < strides.size());
1862  if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
1863  return createIndexConstant(rewriter, loc, strides[idx]);
1864  if (nextSize)
1865  return runningStride
1866  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1867  : nextSize;
1868  assert(!runningStride);
1869  return createIndexConstant(rewriter, loc, 1);
1870  }
1871 
1872  LogicalResult
1873  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1874  ConversionPatternRewriter &rewriter) const override {
1875  auto loc = viewOp.getLoc();
1876 
1877  auto viewMemRefType = viewOp.getType();
1878  auto targetElementTy =
1879  typeConverter->convertType(viewMemRefType.getElementType());
1880  auto targetDescTy = typeConverter->convertType(viewMemRefType);
1881  if (!targetDescTy || !targetElementTy ||
1882  !LLVM::isCompatibleType(targetElementTy) ||
1883  !LLVM::isCompatibleType(targetDescTy))
1884  return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1885  failure();
1886 
1887  int64_t offset;
1888  SmallVector<int64_t, 4> strides;
1889  auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1890  if (failed(successStrides))
1891  return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1892  assert(offset == 0 && "expected offset to be 0");
1893 
1894  // Target memref must be contiguous in memory (innermost stride is 1), or
1895  // empty (special case when at least one of the memref dimensions is 0).
1896  if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1897  return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1898  failure();
1899 
1900  // Create the descriptor.
1901  MemRefDescriptor sourceMemRef(adaptor.getSource());
1902  auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1903 
1904  // Field 1: Copy the allocated pointer, used for malloc/free.
1905  Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1906  auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
1907  Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1908  loc,
1909  LLVM::LLVMPointerType::get(targetElementTy,
1910  srcMemRefType.getMemorySpaceAsInt()),
1911  allocatedPtr);
1912  targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
1913 
1914  // Field 2: Copy the actual aligned pointer to payload.
1915  Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1916  alignedPtr = rewriter.create<LLVM::GEPOp>(
1917  loc, alignedPtr.getType(), alignedPtr, adaptor.getByteShift());
1918  bitcastPtr = rewriter.create<LLVM::BitcastOp>(
1919  loc,
1920  LLVM::LLVMPointerType::get(targetElementTy,
1921  srcMemRefType.getMemorySpaceAsInt()),
1922  alignedPtr);
1923  targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
1924 
1925  // Field 3: The offset in the resulting type must be 0. This is because of
1926  // the type change: an offset on srcType* may not be expressible as an
1927  // offset on dstType*.
1928  targetMemRef.setOffset(rewriter, loc,
1929  createIndexConstant(rewriter, loc, offset));
1930 
1931  // Early exit for 0-D corner case.
1932  if (viewMemRefType.getRank() == 0)
1933  return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1934 
1935  // Fields 4 and 5: Update sizes and strides.
1936  Value stride = nullptr, nextSize = nullptr;
1937  for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1938  // Update size.
1939  Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1940  adaptor.getSizes(), i);
1941  targetMemRef.setSize(rewriter, loc, i, size);
1942  // Update stride.
1943  stride = getStride(rewriter, loc, strides, nextSize, stride, i);
1944  targetMemRef.setStride(rewriter, loc, i, stride);
1945  nextSize = size;
1946  }
1947 
1948  rewriter.replaceOp(viewOp, {targetMemRef});
1949  return success();
1950  }
1951 };
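// As an illustrative example:
//   %v = memref.view %buf[%shift][] : memref<2048xi8> to memref<64x4xf32>
// bitcasts the source pointers to the f32 element pointer type, advances the
// aligned pointer by %shift bytes with llvm.getelementptr, sets the offset to
// 0, and fills in the sizes (64, 4) and strides (4, 1) innermost-first.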
1952 
1953 //===----------------------------------------------------------------------===//
1954 // AtomicRMWOpLowering
1955 //===----------------------------------------------------------------------===//
1956 
1957 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1958 /// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
1959 static Optional<LLVM::AtomicBinOp>
1960 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1961  switch (atomicOp.getKind()) {
1962  case arith::AtomicRMWKind::addf:
1963  return LLVM::AtomicBinOp::fadd;
1964  case arith::AtomicRMWKind::addi:
1965  return LLVM::AtomicBinOp::add;
1966  case arith::AtomicRMWKind::assign:
1967  return LLVM::AtomicBinOp::xchg;
1968  case arith::AtomicRMWKind::maxs:
1969  return LLVM::AtomicBinOp::max;
1970  case arith::AtomicRMWKind::maxu:
1971  return LLVM::AtomicBinOp::umax;
1972  case arith::AtomicRMWKind::mins:
1973  return LLVM::AtomicBinOp::min;
1974  case arith::AtomicRMWKind::minu:
1975  return LLVM::AtomicBinOp::umin;
1976  case arith::AtomicRMWKind::ori:
1977  return LLVM::AtomicBinOp::_or;
1978  case arith::AtomicRMWKind::andi:
1979  return LLVM::AtomicBinOp::_and;
1980  default:
1981  return llvm::None;
1982  }
1983  llvm_unreachable("Invalid AtomicRMWKind");
1984 }
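// For example, arith::AtomicRMWKind::addf maps to llvm.atomicrmw fadd. Kinds
// without a direct llvm.atomicrmw equivalent (e.g. the floating-point min/max
// variants) return llvm::None, so AtomicRMWOpLowering below fails to match
// and they are left for a compare-and-swap based lowering instead.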
1985 
1986 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1987  using Base::Base;
1988 
1989  LogicalResult
1990  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1991  ConversionPatternRewriter &rewriter) const override {
1992  if (failed(match(atomicOp)))
1993  return failure();
1994  auto maybeKind = matchSimpleAtomicOp(atomicOp);
1995  if (!maybeKind)
1996  return failure();
1997  auto resultType = adaptor.getValue().getType();
1998  auto memRefType = atomicOp.getMemRefType();
1999  auto dataPtr =
2000  getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
2001  adaptor.getIndices(), rewriter);
2002  rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
2003  atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
2004  LLVM::AtomicOrdering::acq_rel);
2005  return success();
2006  }
2007 };
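// For example, a memref.atomic_rmw addi on memref<8xi32> lowers to a strided
// element pointer computation followed by an llvm.atomicrmw add with acq_rel
// ordering, whose result replaces the original op.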
2008 
2009 } // namespace
2010 
2011 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
2012  RewritePatternSet &patterns) {
2013  // clang-format off
2014  patterns.add<
2015  AllocaOpLowering,
2016  AllocaScopeOpLowering,
2017  AtomicRMWOpLowering,
2018  AssumeAlignmentOpLowering,
2019  DimOpLowering,
2020  GenericAtomicRMWOpLowering,
2021  GlobalMemrefOpLowering,
2022  GetGlobalMemrefOpLowering,
2023  LoadOpLowering,
2024  MemRefCastOpLowering,
2025  MemRefCopyOpLowering,
2026  MemRefReinterpretCastOpLowering,
2027  MemRefReshapeOpLowering,
2028  PrefetchOpLowering,
2029  RankOpLowering,
2030  ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2031  ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2032  StoreOpLowering,
2033  SubViewOpLowering,
2034  TransposeOpLowering,
2035  ViewOpLowering>(converter);
2036  // clang-format on
2037  auto allocLowering = converter.getOptions().allocLowering;
2038  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2039  patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
2040  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
2041  patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
2042 }
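// A minimal usage sketch for clients outside this file (mirroring
// MemRefToLLVMPass below; `ctx` and `op` are placeholders):
//   LLVMTypeConverter converter(ctx);
//   RewritePatternSet patterns(ctx);
//   populateMemRefToLLVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(*ctx);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();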
2043 
2044 namespace {
2045 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
2046  MemRefToLLVMPass() = default;
2047 
2048  void runOnOperation() override {
2049  Operation *op = getOperation();
2050  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2051  LowerToLLVMOptions options(&getContext(),
2052  dataLayoutAnalysis.getAtOrAbove(op));
2053  options.allocLowering =
2054  (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2055  : LowerToLLVMOptions::AllocLowering::Malloc);
2056 
2057  options.useGenericFunctions = useGenericFunctions;
2058 
2059  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2060  options.overrideIndexBitwidth(indexBitwidth);
2061 
2062  LLVMTypeConverter typeConverter(&getContext(), options,
2063  &dataLayoutAnalysis);
2064  RewritePatternSet patterns(&getContext());
2065  populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
2066  LLVMConversionTarget target(getContext());
2067  target.addLegalOp<func::FuncOp>();
2068  if (failed(applyPartialConversion(op, target, std::move(patterns))))
2069  signalPassFailure();
2070  }
2071 };
2072 } // namespace
2073 
2074 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
2075  return std::make_unique<MemRefToLLVMPass>();
2076 }
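// On the command line this pass is exposed as -convert-memref-to-llvm; the
// options above correspond to flags such as use-aligned-alloc, index-bitwidth,
// and use-generic-functions, e.g.:
//   mlir-opt input.mlir -convert-memref-to-llvm="use-aligned-alloc=1"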