MLIR 23.0.0git
MemRefToLLVM.cpp
Go to the documentation of this file.
1//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
23#include "mlir/IR/AffineMap.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/Pass/Pass.h"
27#include "llvm/Support/DebugLog.h"
28#include "llvm/Support/MathExtras.h"
29
30#include <optional>
31
32#define DEBUG_TYPE "memref-to-llvm"
33
34namespace mlir {
35#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
36#include "mlir/Conversion/Passes.h.inc"
37} // namespace mlir
38
39using namespace mlir;
40
41static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
42 LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
43
44namespace {
45
46static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
47 return ShapedType::isStatic(strideOrOffset);
48}
49
50static FailureOr<LLVM::LLVMFuncOp>
51getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
52 Operation *module, SymbolTableCollection *symbolTables) {
53 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
54
55 if (useGenericFn)
56 return LLVM::lookupOrCreateGenericFreeFn(b, module, symbolTables);
57
58 return LLVM::lookupOrCreateFreeFn(b, module, symbolTables);
59}
60
61static FailureOr<LLVM::LLVMFuncOp>
62getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
63 Operation *module, Type indexType,
64 SymbolTableCollection *symbolTables) {
65 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
66 if (useGenericFn)
67 return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType,
68 symbolTables);
69
70 return LLVM::lookupOrCreateMallocFn(b, module, indexType, symbolTables);
71}
72
73static FailureOr<LLVM::LLVMFuncOp>
74getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
75 Operation *module, Type indexType,
76 SymbolTableCollection *symbolTables) {
77 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
78
79 if (useGenericFn)
80 return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType,
81 symbolTables);
82
83 return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType, symbolTables);
84}
85
86/// Computes the aligned value for 'input' as follows:
87/// bumped = input + alignement - 1
88/// aligned = bumped - bumped % alignment
89static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
90 Value input, Value alignment) {
91 Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(),
92 rewriter.getIndexAttr(1));
93 Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
94 Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
95 Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
96 return LLVM::SubOp::create(rewriter, loc, bumped, mod);
97}
98
99/// Computes the byte size for the MemRef element type.
100static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
101 MemRefType memRefType, Operation *op,
102 const DataLayout *defaultLayout) {
103 const DataLayout *layout = defaultLayout;
104 if (const DataLayoutAnalysis *analysis =
105 typeConverter->getDataLayoutAnalysis()) {
106 layout = &analysis->getAbove(op);
107 }
108 Type elementType = memRefType.getElementType();
109 if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
110 return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
111 if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
112 return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
113 *layout);
114 return layout->getTypeSize(elementType);
115}
116
117static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
118 Location loc, Value allocatedPtr,
119 MemRefType memRefType, Type elementPtrType,
120 const LLVMTypeConverter &typeConverter) {
121 auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
122 FailureOr<unsigned> maybeMemrefAddrSpace =
123 typeConverter.getMemRefAddressSpace(memRefType);
124 assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
125 unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
126 if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
127 allocatedPtr = LLVM::AddrSpaceCastOp::create(
128 rewriter, loc,
129 LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
130 allocatedPtr);
131 return allocatedPtr;
132}
133
134class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
135 SymbolTableCollection *symbolTables = nullptr;
136
137public:
138 explicit AllocOpLowering(const LLVMTypeConverter &typeConverter,
139 SymbolTableCollection *symbolTables = nullptr,
140 PatternBenefit benefit = 1)
141 : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
142 symbolTables(symbolTables) {}
143
144 LogicalResult
145 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
146 ConversionPatternRewriter &rewriter) const override {
147 auto loc = op.getLoc();
148 MemRefType memRefType = op.getType();
149 if (!isConvertibleAndHasIdentityMaps(memRefType))
150 return rewriter.notifyMatchFailure(op, "incompatible memref type");
151
152 // Get or insert alloc function into the module.
153 FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
154 getNotalignedAllocFn(rewriter, getTypeConverter(),
155 op->getParentWithTrait<OpTrait::SymbolTable>(),
156 getIndexType(), symbolTables);
157 if (failed(allocFuncOp))
158 return failure();
159
160 // Get actual sizes of the memref as values: static sizes are constant
161 // values and dynamic sizes are passed to 'alloc' as operands. In case of
162 // zero-dimensional memref, assume a scalar (size 1).
164 SmallVector<Value, 4> strides;
165 Value sizeBytes;
166
167 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
168 rewriter, sizes, strides, sizeBytes, true);
169
170 Value alignment = getAlignment(rewriter, loc, op);
171 if (alignment) {
172 // Adjust the allocation size to consider alignment.
173 sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
174 }
175
176 // Allocate the underlying buffer.
177 Type elementPtrType = this->getElementPtrType(memRefType);
178 assert(elementPtrType && "could not compute element ptr type");
179 auto results =
180 LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);
181
182 Value allocatedPtr =
183 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
184 elementPtrType, *getTypeConverter());
185 Value alignedPtr = allocatedPtr;
186 if (alignment) {
187 // Compute the aligned pointer.
188 Value allocatedInt =
189 LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
190 Value alignmentInt =
191 createAligned(rewriter, loc, allocatedInt, alignment);
192 alignedPtr =
193 LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
194 }
195
196 // Create the MemRef descriptor.
197 auto memRefDescriptor = this->createMemRefDescriptor(
198 loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
199
200 // Return the final value of the descriptor.
201 rewriter.replaceOp(op, {memRefDescriptor});
202 return success();
203 }
204
205 /// Computes the alignment for the given memory allocation op.
206 template <typename OpType>
207 Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
208 OpType op) const {
209 MemRefType memRefType = op.getType();
210 Value alignment;
211 if (auto alignmentAttr = op.getAlignment()) {
212 Type indexType = getIndexType();
213 alignment =
214 createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
215 } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
216 // In the case where no alignment is specified, we may want to override
217 // `malloc's` behavior. `malloc` typically aligns at the size of the
218 // biggest scalar on a target HW. For non-scalars, use the natural
219 // alignment of the LLVM type given by the LLVM DataLayout.
220 alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
221 }
222 return alignment;
223 }
224};
225
226class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
227 SymbolTableCollection *symbolTables = nullptr;
228
229public:
230 explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter,
231 SymbolTableCollection *symbolTables = nullptr,
232 PatternBenefit benefit = 1)
233 : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
234 symbolTables(symbolTables) {}
235
236 LogicalResult
237 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
238 ConversionPatternRewriter &rewriter) const override {
239 auto loc = op.getLoc();
240 MemRefType memRefType = op.getType();
241 if (!isConvertibleAndHasIdentityMaps(memRefType))
242 return rewriter.notifyMatchFailure(op, "incompatible memref type");
243
244 // Get or insert alloc function into module.
245 FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
246 getAlignedAllocFn(rewriter, getTypeConverter(),
247 op->getParentWithTrait<OpTrait::SymbolTable>(),
248 getIndexType(), symbolTables);
249 if (failed(allocFuncOp))
250 return failure();
251
252 // Get actual sizes of the memref as values: static sizes are constant
253 // values and dynamic sizes are passed to 'alloc' as operands. In case of
254 // zero-dimensional memref, assume a scalar (size 1).
256 SmallVector<Value, 4> strides;
257 Value sizeBytes;
258
259 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
260 rewriter, sizes, strides, sizeBytes, !false);
261
262 int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
263
264 Value allocAlignment =
265 createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
266
267 // Function aligned_alloc requires size to be a multiple of alignment; we
268 // pad the size to the next multiple if necessary.
269 if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
270 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
271
272 Type elementPtrType = this->getElementPtrType(memRefType);
273 auto results =
274 LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
275 ValueRange({allocAlignment, sizeBytes}));
276
277 Value ptr =
278 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
279 elementPtrType, *getTypeConverter());
280
281 // Create the MemRef descriptor.
282 auto memRefDescriptor = this->createMemRefDescriptor(
283 loc, memRefType, ptr, ptr, sizes, strides, rewriter);
284
285 // Return the final value of the descriptor.
286 rewriter.replaceOp(op, {memRefDescriptor});
287 return success();
288 }
289
290 /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
291 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
292
293 /// Computes the alignment for aligned_alloc used to allocate the buffer for
294 /// the memory allocation op.
295 ///
296 /// Aligned_alloc requires the allocation size to be a power of two, and the
297 /// allocation size to be a multiple of the alignment.
298 int64_t alignedAllocationGetAlignment(memref::AllocOp op,
299 const DataLayout *defaultLayout) const {
300 if (std::optional<uint64_t> alignment = op.getAlignment())
301 return *alignment;
302
303 // Whenever we don't have alignment set, we will use an alignment
304 // consistent with the element type; since the allocation size has to be a
305 // power of two, we will bump to the next power of two if it isn't.
306 unsigned eltSizeBytes = getMemRefEltSizeInBytes(
307 getTypeConverter(), op.getType(), op, defaultLayout);
308 return std::max(kMinAlignedAllocAlignment,
309 llvm::PowerOf2Ceil(eltSizeBytes));
310 }
311
312 /// Returns true if the memref size in bytes is known to be a multiple of
313 /// factor.
314 bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
315 const DataLayout *defaultLayout) const {
316 uint64_t sizeDivisor =
317 getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
318 for (unsigned i = 0, e = type.getRank(); i < e; i++) {
319 if (type.isDynamicDim(i))
320 continue;
321 sizeDivisor = sizeDivisor * type.getDimSize(i);
322 }
323 return sizeDivisor % factor == 0;
324 }
325
326private:
327 /// Default layout to use in absence of the corresponding analysis.
328 DataLayout defaultLayout;
329};
330
331struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
333
334 /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
335 /// is set to null for stack allocations. `accessAlignment` is set if
336 /// alignment is needed post allocation (for eg. in conjunction with malloc).
337 LogicalResult
338 matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter) const override {
340 auto loc = op.getLoc();
341 MemRefType memRefType = op.getType();
342 if (!isConvertibleAndHasIdentityMaps(memRefType))
343 return rewriter.notifyMatchFailure(op, "incompatible memref type");
344
345 // Get actual sizes of the memref as values: static sizes are constant
346 // values and dynamic sizes are passed to 'alloc' as operands. In case of
347 // zero-dimensional memref, assume a scalar (size 1).
349 SmallVector<Value, 4> strides;
350 Value size;
351
352 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
353 rewriter, sizes, strides, size, !true);
354
355 // With alloca, one gets a pointer to the element type right away.
356 // For stack allocations.
357 auto elementType =
358 typeConverter->convertType(op.getType().getElementType());
359 FailureOr<unsigned> maybeAddressSpace =
360 getTypeConverter()->getMemRefAddressSpace(op.getType());
361 assert(succeeded(maybeAddressSpace) && "unsupported address space");
362 unsigned addrSpace = *maybeAddressSpace;
363 auto elementPtrType =
364 LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
365
366 auto allocatedElementPtr =
367 LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
368 op.getAlignment().value_or(0));
369
370 // Create the MemRef descriptor.
371 auto memRefDescriptor = this->createMemRefDescriptor(
372 loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
373 strides, rewriter);
374
375 // Return the final value of the descriptor.
376 rewriter.replaceOp(op, {memRefDescriptor});
377 return success();
378 }
379};
380
/// Lowers `memref.alloca_scope` by inlining its body between an
/// `llvm.intr.stacksave` / `llvm.intr.stackrestore` pair, so stack
/// allocations made inside the scope are reclaimed on exit.
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      // No results to forward: control continues directly in the split-off
      // block.
      continueBlock = remainingOpsBlock;
    } else {
      // With results, a dedicated continuation block is created whose block
      // arguments carry the scope's results; it forwards to the remaining ops.
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region. Entry/exit blocks are captured before the region is
    // moved, since the region is empty afterwards.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
    LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out the body of the region.
    rewriter.setInsertionPoint(branchOp);
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    // Replace the op with values return from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};
435
436struct AssumeAlignmentOpLowering
437 : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
439 memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
440 explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
442
443 LogicalResult
444 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
445 ConversionPatternRewriter &rewriter) const override {
446 Value memref = adaptor.getMemref();
447 unsigned alignment = op.getAlignment();
448 auto loc = op.getLoc();
449
450 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
451 Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
452 /*indices=*/{});
453
454 // Emit llvm.assume(true) ["align"(memref, alignment)].
455 // This is more direct than ptrtoint-based checks, is explicitly supported,
456 // and works with non-integral address spaces.
457 Value trueCond =
458 LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
459 Value alignmentConst =
460 createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
461 LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr,
462 alignmentConst);
463 rewriter.replaceOp(op, memref);
464 return success();
465 }
466};
467
468struct DistinctObjectsOpLowering
469 : public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
471 memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
472 explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter)
474
475 LogicalResult
476 matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
477 ConversionPatternRewriter &rewriter) const override {
478 ValueRange operands = adaptor.getOperands();
479 if (operands.size() <= 1) {
480 // Fast path.
481 rewriter.replaceOp(op, operands);
482 return success();
483 }
484
485 Location loc = op.getLoc();
487 for (auto [origOperand, newOperand] :
488 llvm::zip_equal(op.getOperands(), operands)) {
489 auto memrefType = cast<MemRefType>(origOperand.getType());
490 MemRefDescriptor memRefDescriptor(newOperand);
491 Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
492 memrefType);
493 ptrs.push_back(ptr);
494 }
495
496 auto cond =
497 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
498 // Generate separate_storage assumptions for each pair of pointers.
499 for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
500 for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
501 Value ptr1 = ptrs[i];
502 Value ptr2 = ptrs[j];
503 LLVM::AssumeOp::create(rewriter, loc, cond,
504 LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2);
505 }
506 }
507
508 rewriter.replaceOp(op, operands);
509 return success();
510 }
511};
512
513// A `dealloc` is converted into a call to `free` on the underlying data buffer.
514// The memref descriptor being an SSA value, there is no need to clean it up
515// in any way.
516class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
517 SymbolTableCollection *symbolTables = nullptr;
518
519public:
520 explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter,
521 SymbolTableCollection *symbolTables = nullptr,
522 PatternBenefit benefit = 1)
523 : ConvertOpToLLVMPattern<memref::DeallocOp>(typeConverter, benefit),
524 symbolTables(symbolTables) {}
525
526 LogicalResult
527 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
528 ConversionPatternRewriter &rewriter) const override {
529 // Insert the `free` declaration if it is not already present.
530 FailureOr<LLVM::LLVMFuncOp> freeFunc =
531 getFreeFn(rewriter, getTypeConverter(),
532 op->getParentWithTrait<OpTrait::SymbolTable>(), symbolTables);
533 if (failed(freeFunc))
534 return failure();
535 Value allocatedPtr;
536 if (auto unrankedTy =
537 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
538 auto elementPtrTy = LLVM::LLVMPointerType::get(
539 rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
541 rewriter, op.getLoc(),
542 UnrankedMemRefDescriptor(adaptor.getMemref())
543 .memRefDescPtr(rewriter, op.getLoc()),
544 elementPtrTy);
545 } else {
546 allocatedPtr = MemRefDescriptor(adaptor.getMemref())
547 .allocatedPtr(rewriter, op.getLoc());
548 }
549 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
550 allocatedPtr);
551 return success();
552 }
553};
554
555// A `dim` is converted to a constant for static sizes and to an access to the
556// size stored in the memref descriptor for dynamic sizes.
557struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
559
560 LogicalResult
561 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
562 ConversionPatternRewriter &rewriter) const override {
563 Type operandType = dimOp.getSource().getType();
564 if (isa<UnrankedMemRefType>(operandType)) {
565 FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
566 operandType, dimOp, adaptor.getOperands(), rewriter);
567 if (failed(extractedSize))
568 return failure();
569 rewriter.replaceOp(dimOp, {*extractedSize});
570 return success();
571 }
572 if (isa<MemRefType>(operandType)) {
573 rewriter.replaceOp(
574 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
575 adaptor.getOperands(), rewriter)});
576 return success();
577 }
578 llvm_unreachable("expected MemRefType or UnrankedMemRefType");
579 }
580
581private:
582 FailureOr<Value>
583 extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
584 OpAdaptor adaptor,
585 ConversionPatternRewriter &rewriter) const {
586 Location loc = dimOp.getLoc();
587
588 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
589 auto scalarMemRefType =
590 MemRefType::get({}, unrankedMemRefType.getElementType());
591 FailureOr<unsigned> maybeAddressSpace =
592 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
593 if (failed(maybeAddressSpace)) {
594 dimOp.emitOpError("memref memory space must be convertible to an integer "
595 "address space");
596 return failure();
597 }
598 unsigned addressSpace = *maybeAddressSpace;
599
600 // Extract pointer to the underlying ranked descriptor and bitcast it to a
601 // memref<element_type> descriptor pointer to minimize the number of GEP
602 // operations.
603 UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
604 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
605
606 Type elementType = typeConverter->convertType(scalarMemRefType);
607
608 // Get pointer to offset field of memref<element_type> descriptor.
609 auto indexPtrTy =
610 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
611 Value offsetPtr =
612 LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
613 underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2});
614
615 // The size value that we have to extract can be obtained using GEPop with
616 // `dimOp.index() + 1` index argument.
617 Value idxPlusOne = LLVM::AddOp::create(
618 rewriter, loc,
619 createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
620 adaptor.getIndex());
621 Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
622 getTypeConverter()->getIndexType(),
623 offsetPtr, idxPlusOne);
624 return LLVM::LoadOp::create(rewriter, loc,
625 getTypeConverter()->getIndexType(), sizePtr)
626 .getResult();
627 }
628
629 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
630 if (auto idx = dimOp.getConstantIndex())
631 return idx;
632
633 if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
634 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
635
636 return std::nullopt;
637 }
638
639 Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
640 OpAdaptor adaptor,
641 ConversionPatternRewriter &rewriter) const {
642 Location loc = dimOp.getLoc();
643
644 // Take advantage if index is constant.
645 MemRefType memRefType = cast<MemRefType>(operandType);
646 Type indexType = getIndexType();
647 if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
648 int64_t i = *index;
649 if (i >= 0 && i < memRefType.getRank()) {
650 if (memRefType.isDynamicDim(i)) {
651 // extract dynamic size from the memref descriptor.
652 MemRefDescriptor descriptor(adaptor.getSource());
653 return descriptor.size(rewriter, loc, i);
654 }
655 // Use constant for static size.
656 int64_t dimSize = memRefType.getDimSize(i);
657 return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
658 }
659 }
660 Value index = adaptor.getIndex();
661 int64_t rank = memRefType.getRank();
662 MemRefDescriptor memrefDescriptor(adaptor.getSource());
663 return memrefDescriptor.size(rewriter, loc, index, rank);
664 }
665};
666
667/// Common base for load and store operations on MemRefs. Restricts the match
668/// to supported MemRef types. Provides functionality to emit code accessing a
669/// specific element of the underlying data buffer.
670template <typename Derived>
671struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
673 using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
674 using Base = LoadStoreOpLowering<Derived>;
675};
676
677/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
678/// retried until it succeeds in atomically storing a new value into memory.
679///
680/// +---------------------------------+
681/// | <code before the AtomicRMWOp> |
682/// | <compute initial %loaded> |
683/// | cf.br loop(%loaded) |
684/// +---------------------------------+
685/// |
686/// -------| |
687/// | v v
688/// | +--------------------------------+
689/// | | loop(%loaded): |
690/// | | <body contents> |
691/// | | %pair = cmpxchg |
692/// | | %ok = %pair[0] |
693/// | | %new = %pair[1] |
694/// | | cf.cond_br %ok, end, loop(%new) |
695/// | +--------------------------------+
696/// | | |
697/// |----------- |
698/// v
699/// +--------------------------------+
700/// | end: |
701/// | <code after the AtomicRMWOp> |
702/// +--------------------------------+
703///
/// Lowers `memref.generic_atomic_rmw` to a cmpxchg retry loop (see the block
/// diagram above): the body region computes the candidate value, and the loop
/// repeats until the compare-exchange succeeds.
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    // The loop's block argument carries the most recently observed value.
    loopBlock->addArgument(valueType, loc);

    // NOTE(review): the post-increment on this temporary iterator has no
    // effect (post-increment yields the pre-increment value), so the split
    // point is `atomicOp` itself; the op is replaced below regardless —
    // confirm against upstream whether std::next was intended.
    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(
        rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
    Value init = LLVM::LoadOp::create(
        rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
        dataPtr);
    LLVM::BrOp::create(rewriter, loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result, remapping
    // the region's current-value argument to the loop block argument.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }

    // The terminator's operand, after remapping, is the value to store.
    Value result =
        mapping.lookupOrNull(entryBlock.getTerminator()->getOperand(0));
    if (!result) {
      return atomicOp.emitError("result not defined in region");
    }

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg =
        LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
                                      result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
    Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef<Value>(),
                           loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};
774
775/// Returns the LLVM type of the global variable given the memref type `type`.
776static Type
777convertGlobalMemrefTypeToLLVM(MemRefType type,
778 const LLVMTypeConverter &typeConverter) {
779 // LLVM type for a global memref will be a multi-dimension array. For
780 // declarations or uninitialized global memrefs, we can potentially flatten
781 // this to a 1D array. However, for memref.global's with an initial value,
782 // we do not intend to flatten the ElementsAttribute when going from std ->
783 // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
784 Type elementType = typeConverter.convertType(type.getElementType());
785 Type arrayTy = elementType;
786 // Shape has the outermost dim at index 0, so need to walk it backwards
787 for (int64_t dim : llvm::reverse(type.getShape()))
788 arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
789 return arrayTy;
790}
791
792/// GlobalMemrefOp is lowered to a LLVM Global Variable.
793class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
794 SymbolTableCollection *symbolTables = nullptr;
795
796public:
797 explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter,
798 SymbolTableCollection *symbolTables = nullptr,
799 PatternBenefit benefit = 1)
800 : ConvertOpToLLVMPattern<memref::GlobalOp>(typeConverter, benefit),
801 symbolTables(symbolTables) {}
802
803 LogicalResult
804 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
805 ConversionPatternRewriter &rewriter) const override {
806 MemRefType type = global.getType();
807 if (!isConvertibleAndHasIdentityMaps(type))
808 return failure();
809
810 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
811
812 LLVM::Linkage linkage =
813 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
814 bool isExternal = global.isExternal();
815 bool isUninitialized = global.isUninitialized();
816
817 Attribute initialValue = nullptr;
818 if (!isExternal && !isUninitialized) {
819 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
820 initialValue = elementsAttr;
821
822 // For scalar memrefs, the global variable created is of the element type,
823 // so unpack the elements attribute to extract the value.
824 if (type.getRank() == 0)
825 initialValue = elementsAttr.getSplatValue<Attribute>();
826 }
827
828 uint64_t alignment = global.getAlignment().value_or(0);
829 FailureOr<unsigned> addressSpace =
830 getTypeConverter()->getMemRefAddressSpace(type);
831 if (failed(addressSpace))
832 return global.emitOpError(
833 "memory space cannot be converted to an integer address space");
834
835 // Remove old operation from symbol table.
836 SymbolTable *symbolTable = nullptr;
837 if (symbolTables) {
838 Operation *symbolTableOp =
840 symbolTable = &symbolTables->getSymbolTable(symbolTableOp);
841 symbolTable->remove(global);
842 }
843
844 // Create new operation.
845 auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
846 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
847 initialValue, alignment, *addressSpace);
848
849 // Insert new operation into symbol table.
850 if (symbolTable)
851 symbolTable->insert(newGlobal, rewriter.getInsertionPoint());
852
853 if (!isExternal && isUninitialized) {
854 rewriter.createBlock(&newGlobal.getInitializerRegion());
855 Value undef[] = {
856 LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
857 LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
858 }
859 return success();
860 }
861};
862
863/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
864/// the first element stashed into the descriptor. This reuses
865/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
866struct GetGlobalMemrefOpLowering
867 : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
868 using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;
869
870 /// Buffer "allocation" for memref.get_global op is getting the address of
871 /// the global variable referenced.
872 LogicalResult
873 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
874 ConversionPatternRewriter &rewriter) const override {
875 auto loc = op.getLoc();
876 MemRefType memRefType = op.getType();
877 if (!isConvertibleAndHasIdentityMaps(memRefType))
878 return rewriter.notifyMatchFailure(op, "incompatible memref type");
879
880 // Get actual sizes of the memref as values: static sizes are constant
881 // values and dynamic sizes are passed to 'alloc' as operands. In case of
882 // zero-dimensional memref, assume a scalar (size 1).
884 SmallVector<Value, 4> strides;
885 Value sizeBytes;
886
887 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
888 rewriter, sizes, strides, sizeBytes, !false);
889
890 MemRefType type = cast<MemRefType>(op.getResult().getType());
891
892 // This is called after a type conversion, which would have failed if this
893 // call fails.
894 FailureOr<unsigned> maybeAddressSpace =
895 getTypeConverter()->getMemRefAddressSpace(type);
896 assert(succeeded(maybeAddressSpace) && "unsupported address space");
897 unsigned memSpace = *maybeAddressSpace;
898
899 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
900 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
901 auto addressOf =
902 LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());
903
904 // Get the address of the first element in the array by creating a GEP with
905 // the address of the GV as the base, and (rank + 1) number of 0 indices.
906 auto gep =
907 LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
908 SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
909
910 // We do not expect the memref obtained using `memref.get_global` to be
911 // ever deallocated. Set the allocated pointer to be known bad value to
912 // help debug if that ever happens.
913 auto intPtrType = getIntPtrType(memSpace);
914 Value deadBeefConst =
915 createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
916 auto deadBeefPtr =
917 LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);
918
919 // Both allocated and aligned pointers are same. We could potentially stash
920 // a nullptr for the allocated pointer since we do not expect any dealloc.
921 // Create the MemRef descriptor.
922 auto memRefDescriptor = this->createMemRefDescriptor(
923 loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);
924
925 // Return the final value of the descriptor.
926 rewriter.replaceOp(op, {memRefDescriptor});
927 return success();
928 }
929};
930
931// Load operation is lowered to obtaining a pointer to the indexed element
932// and loading it.
933struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
934 using Base::Base;
935
936 LogicalResult
937 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
938 ConversionPatternRewriter &rewriter) const override {
939 auto type = loadOp.getMemRefType();
940
941 // Per memref.load spec, the indices must be in-bounds:
942 // 0 <= idx < dim_size, and additionally all offsets are non-negative,
943 // hence inbounds and nuw are used when lowering to llvm.getelementptr.
944 Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
945 adaptor.getMemref(),
946 adaptor.getIndices(), kNoWrapFlags);
947 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
948 loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
949 loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal());
950 return success();
951 }
952};
953
954// Store operation is lowered to obtaining a pointer to the indexed element,
955// and storing the given value to it.
956struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
957 using Base::Base;
958
959 LogicalResult
960 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
961 ConversionPatternRewriter &rewriter) const override {
962 auto type = op.getMemRefType();
963
964 // Per memref.store spec, the indices must be in-bounds:
965 // 0 <= idx < dim_size, and additionally all offsets are non-negative,
966 // hence inbounds and nuw are used when lowering to llvm.getelementptr.
967 Value dataPtr =
968 getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
969 adaptor.getIndices(), kNoWrapFlags);
970 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
971 op.getAlignment().value_or(0),
972 false, op.getNontemporal());
973 return success();
974 }
975};
976
977// The prefetch operation is lowered in a way similar to the load operation
978// except that the llvm.prefetch operation is used for replacement.
979struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
980 using Base::Base;
981
982 LogicalResult
983 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
984 ConversionPatternRewriter &rewriter) const override {
985 auto type = prefetchOp.getMemRefType();
986 auto loc = prefetchOp.getLoc();
987
988 Value dataPtr = getStridedElementPtr(
989 rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
990
991 // Replace with llvm.prefetch.
992 IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
993 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
994 IntegerAttr isData =
995 rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
996 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
997 localityHint, isData);
998 return success();
999 }
1000};
1001
1002struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
1004
1005 LogicalResult
1006 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
1007 ConversionPatternRewriter &rewriter) const override {
1008 Location loc = op.getLoc();
1009 Type operandType = op.getMemref().getType();
1010 if (isa<UnrankedMemRefType>(operandType)) {
1011 UnrankedMemRefDescriptor desc(adaptor.getMemref());
1012 rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
1013 return success();
1014 }
1015 if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
1016 Type indexType = getIndexType();
1017 rewriter.replaceOp(op,
1018 {createIndexAttrConstant(rewriter, loc, indexType,
1019 rankedMemRefType.getRank())});
1020 return success();
1021 }
1022 return failure();
1023 }
1024};
1025
1026struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
1028
1029 LogicalResult
1030 matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
1031 ConversionPatternRewriter &rewriter) const override {
1032 Type srcType = memRefCastOp.getOperand().getType();
1033 Type dstType = memRefCastOp.getType();
1034
1035 // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
1036 // used for type erasure. For now they must preserve underlying element type
1037 // and require source and result type to have the same rank. Therefore,
1038 // perform a sanity check that the underlying structs are the same. Once op
1039 // semantics are relaxed we can revisit.
1040 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
1041 if (typeConverter->convertType(srcType) !=
1042 typeConverter->convertType(dstType))
1043 return failure();
1044
1045 // Unranked to unranked cast is disallowed
1046 if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
1047 return failure();
1048
1049 auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
1050 auto loc = memRefCastOp.getLoc();
1051
1052 // For ranked/ranked case, just keep the original descriptor.
1053 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
1054 rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
1055 return success();
1056 }
1057
1058 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
1059 // Casting ranked to unranked memref type
1060 // Set the rank in the destination from the memref type
1061 // Allocate space on the stack and copy the src memref descriptor
1062 // Set the ptr in the destination to the stack space
1063 auto srcMemRefType = cast<MemRefType>(srcType);
1064 int64_t rank = srcMemRefType.getRank();
1065 // ptr = AllocaOp sizeof(MemRefDescriptor)
1066 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
1067 loc, adaptor.getSource(), rewriter);
1068
1069 // rank = ConstantOp srcRank
1070 auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1071 rewriter.getIndexAttr(rank));
1072 // poison = PoisonOp
1073 UnrankedMemRefDescriptor memRefDesc =
1074 UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
1075 // d1 = InsertValueOp poison, rank, 0
1076 memRefDesc.setRank(rewriter, loc, rankVal);
1077 // d2 = InsertValueOp d1, ptr, 1
1078 memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
1079 rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
1080
1081 } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
1082 // Casting from unranked type to ranked.
1083 // The operation is assumed to be doing a correct cast. If the destination
1084 // type mismatches the unranked the type, it is undefined behavior.
1085 UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
1086 // ptr = ExtractValueOp src, 1
1087 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
1088
1089 // struct = LoadOp ptr
1090 auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr);
1091 rewriter.replaceOp(memRefCastOp, loadOp.getResult());
1092 } else {
1093 llvm_unreachable("Unsupported unranked memref to unranked memref cast");
1094 }
1095
1096 return success();
1097 }
1098};
1099
1100/// Pattern to lower a `memref.copy` to llvm.
1101///
1102/// For memrefs with identity layouts, the copy is lowered to the llvm
1103/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
1104/// to the generic `MemrefCopyFn`.
1105class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
1106 SymbolTableCollection *symbolTables = nullptr;
1107
1108public:
1109 explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter,
1110 SymbolTableCollection *symbolTables = nullptr,
1111 PatternBenefit benefit = 1)
1112 : ConvertOpToLLVMPattern<memref::CopyOp>(typeConverter, benefit),
1113 symbolTables(symbolTables) {}
1114
1115 LogicalResult
1116 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
1117 ConversionPatternRewriter &rewriter) const {
1118 auto loc = op.getLoc();
1119 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
1120
1121 MemRefDescriptor srcDesc(adaptor.getSource());
1122
1123 // Compute number of elements.
1124 Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1125 rewriter.getIndexAttr(1));
1126 for (int pos = 0; pos < srcType.getRank(); ++pos) {
1127 auto size = srcDesc.size(rewriter, loc, pos);
1128 numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
1129 }
1130
1131 // Get element size.
1132 auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
1133 // Compute total.
1134 Value totalSize =
1135 LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);
1136
1137 Type elementType = typeConverter->convertType(srcType.getElementType());
1138
1139 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
1140 Value srcOffset = srcDesc.offset(rewriter, loc);
1141 Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(),
1142 elementType, srcBasePtr, srcOffset);
1143 MemRefDescriptor targetDesc(adaptor.getTarget());
1144 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
1145 Value targetOffset = targetDesc.offset(rewriter, loc);
1146 Value targetPtr =
1147 LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType,
1148 targetBasePtr, targetOffset);
1149 LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
1150 /*isVolatile=*/false);
1151 rewriter.eraseOp(op);
1152
1153 return success();
1154 }
1155
1156 LogicalResult
1157 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
1158 ConversionPatternRewriter &rewriter) const {
1159 auto loc = op.getLoc();
1160 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1161 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1162
1163 // First make sure we have an unranked memref descriptor representation.
1164 auto makeUnranked = [&, this](Value ranked, MemRefType type) {
1165 auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1166 type.getRank());
1167 auto *typeConverter = getTypeConverter();
1168 auto ptr =
1169 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
1170
1171 auto unrankedType =
1172 UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
1174 rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
1175 };
1176
1177 // Save stack position before promoting descriptors
1178 auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
1179
1180 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
1181 Value unrankedSource =
1182 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
1183 : adaptor.getSource();
1184 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
1185 Value unrankedTarget =
1186 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
1187 : adaptor.getTarget();
1188
1189 // Now promote the unranked descriptors to the stack.
1190 auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1191 rewriter.getIndexAttr(1));
1192 auto promote = [&](Value desc) {
1193 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1194 auto allocated =
1195 LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
1196 LLVM::StoreOp::create(rewriter, loc, desc, allocated);
1197 return allocated;
1198 };
1199
1200 auto sourcePtr = promote(unrankedSource);
1201 auto targetPtr = promote(unrankedTarget);
1202
1203 // Derive size from llvm.getelementptr which will account for any
1204 // potential alignment
1205 auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
1207 rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
1208 sourcePtr.getType(), symbolTables);
1209 if (failed(copyFn))
1210 return failure();
1211 LLVM::CallOp::create(rewriter, loc, copyFn.value(),
1212 ValueRange{elemSize, sourcePtr, targetPtr});
1213
1214 // Restore stack used for descriptors
1215 LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
1216
1217 rewriter.eraseOp(op);
1218
1219 return success();
1220 }
1221
1222 LogicalResult
1223 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
1224 ConversionPatternRewriter &rewriter) const override {
1225 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1226 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1227
1228 auto isContiguousMemrefType = [&](BaseMemRefType type) {
1229 auto memrefType = dyn_cast<mlir::MemRefType>(type);
1230 // We can use memcpy for memrefs if they have an identity layout or are
1231 // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
1232 // special case handled by memrefCopy.
1233 return memrefType &&
1234 (memrefType.getLayout().isIdentity() ||
1235 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
1237 };
1238
1239 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
1240 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
1241
1242 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
1243 }
1244};
1245
1246struct MemorySpaceCastOpLowering
1247 : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
1249 memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;
1250
1251 LogicalResult
1252 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
1253 ConversionPatternRewriter &rewriter) const override {
1254 Location loc = op.getLoc();
1255
1256 Type resultType = op.getDest().getType();
1257 if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
1258 auto convertedType =
1259 typeConverter->convertType<LLVM::LLVMStructType>(resultTypeR);
1260 if (!convertedType)
1261 return rewriter.notifyMatchFailure(op, "memref type conversion failed");
1262 Type newPtrType = convertedType.getBody()[0];
1263
1264 SmallVector<Value> descVals;
1265 MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
1266 descVals);
1267 descVals[0] =
1268 LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
1269 descVals[1] =
1270 LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
1271 Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
1272 resultTypeR, descVals);
1273 rewriter.replaceOp(op, result);
1274 return success();
1275 }
1276 if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
1277 // Since the type converter won't be doing this for us, get the address
1278 // space.
1279 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
1280 FailureOr<unsigned> maybeSourceAddrSpace =
1281 getTypeConverter()->getMemRefAddressSpace(sourceType);
1282 if (failed(maybeSourceAddrSpace))
1283 return rewriter.notifyMatchFailure(loc,
1284 "non-integer source address space");
1285 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
1286 FailureOr<unsigned> maybeResultAddrSpace =
1287 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
1288 if (failed(maybeResultAddrSpace))
1289 return rewriter.notifyMatchFailure(loc,
1290 "non-integer result address space");
1291 unsigned resultAddrSpace = *maybeResultAddrSpace;
1292
1293 UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
1294 Value rank = sourceDesc.rank(rewriter, loc);
1295 Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
1296
1297 // Create and allocate storage for new memref descriptor.
1299 rewriter, loc, typeConverter->convertType(resultTypeU));
1300 result.setRank(rewriter, loc, rank);
1301 Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
1302 rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
1303 Value resultUnderlyingDesc =
1304 LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
1305 rewriter.getI8Type(), resultUnderlyingSize);
1306 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
1307
1308 // Copy pointers, performing address space casts.
1309 auto sourceElemPtrType =
1310 LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
1311 auto resultElemPtrType =
1312 LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);
1313
1314 Value allocatedPtr = sourceDesc.allocatedPtr(
1315 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
1316 Value alignedPtr =
1317 sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
1318 sourceUnderlyingDesc, sourceElemPtrType);
1319 allocatedPtr = LLVM::AddrSpaceCastOp::create(
1320 rewriter, loc, resultElemPtrType, allocatedPtr);
1321 alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
1322 resultElemPtrType, alignedPtr);
1323
1324 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
1325 resultElemPtrType, allocatedPtr);
1326 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
1327 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
1328
1329 // Copy all the index-valued operands.
1330 Value sourceIndexVals =
1331 sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1332 sourceUnderlyingDesc, sourceElemPtrType);
1333 Value resultIndexVals =
1334 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1335 resultUnderlyingDesc, resultElemPtrType);
1336
1337 int64_t bytesToSkip =
1338 2 * llvm::divideCeil(
1339 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
1340 Value bytesToSkipConst = LLVM::ConstantOp::create(
1341 rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
1342 Value copySize =
1343 LLVM::SubOp::create(rewriter, loc, getIndexType(),
1344 resultUnderlyingSize, bytesToSkipConst);
1345 LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
1346 copySize, /*isVolatile=*/false);
1347
1348 rewriter.replaceOp(op, ValueRange{result});
1349 return success();
1350 }
1351 return rewriter.notifyMatchFailure(loc, "unexpected memref type");
1352 }
1353};
1354
1355/// Extracts allocated, aligned pointers and offset from a ranked or unranked
1356/// memref type. In unranked case, the fields are extracted from the underlying
1357/// ranked descriptor.
1358static void extractPointersAndOffset(Location loc,
1359 ConversionPatternRewriter &rewriter,
1360 const LLVMTypeConverter &typeConverter,
1361 Value originalOperand,
1362 Value convertedOperand,
1363 Value *allocatedPtr, Value *alignedPtr,
1364 Value *offset = nullptr) {
1365 Type operandType = originalOperand.getType();
1366 if (isa<MemRefType>(operandType)) {
1367 MemRefDescriptor desc(convertedOperand);
1368 *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1369 *alignedPtr = desc.alignedPtr(rewriter, loc);
1370 if (offset != nullptr)
1371 *offset = desc.offset(rewriter, loc);
1372 return;
1373 }
1374
1375 // These will all cause assert()s on unconvertible types.
1376 unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
1377 cast<UnrankedMemRefType>(operandType));
1378 auto elementPtrType =
1379 LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);
1380
1381 // Extract pointer to the underlying ranked memref descriptor and cast it to
1382 // ElemType**.
1383 UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1384 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1385
1387 rewriter, loc, underlyingDescPtr, elementPtrType);
1389 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1390 if (offset != nullptr) {
1392 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1393 }
1394}
1395
1396struct MemRefReinterpretCastOpLowering
1397 : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1399 memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1400
1401 LogicalResult
1402 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1403 ConversionPatternRewriter &rewriter) const override {
1404 Type srcType = castOp.getSource().getType();
1405
1406 Value descriptor;
1407 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1408 adaptor, &descriptor)))
1409 return failure();
1410 rewriter.replaceOp(castOp, {descriptor});
1411 return success();
1412 }
1413
1414private:
1415 LogicalResult convertSourceMemRefToDescriptor(
1416 ConversionPatternRewriter &rewriter, Type srcType,
1417 memref::ReinterpretCastOp castOp,
1418 memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1419 MemRefType targetMemRefType =
1420 cast<MemRefType>(castOp.getResult().getType());
1421 auto llvmTargetDescriptorTy =
1422 typeConverter->convertType<LLVM::LLVMStructType>(targetMemRefType);
1423 if (!llvmTargetDescriptorTy)
1424 return failure();
1425
1426 // Create descriptor.
1427 Location loc = castOp.getLoc();
1428 auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1429
1430 // Set allocated and aligned pointers.
1431 Value allocatedPtr, alignedPtr;
1432 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1433 castOp.getSource(), adaptor.getSource(),
1434 &allocatedPtr, &alignedPtr);
1435 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1436 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1437
1438 // Set offset.
1439 if (castOp.isDynamicOffset(0))
1440 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1441 else
1442 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1443
1444 // Set sizes and strides.
1445 unsigned dynSizeId = 0;
1446 unsigned dynStrideId = 0;
1447 for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1448 if (castOp.isDynamicSize(i))
1449 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1450 else
1451 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1452
1453 if (castOp.isDynamicStride(i))
1454 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1455 else
1456 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1457 }
1458 *descriptor = desc;
1459 return success();
1460 }
1461};
1462
1463struct MemRefReshapeOpLowering
1464 : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1465 using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1466
1467 LogicalResult
1468 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1469 ConversionPatternRewriter &rewriter) const override {
1470 Type srcType = reshapeOp.getSource().getType();
1471
1472 Value descriptor;
1473 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1474 adaptor, &descriptor)))
1475 return failure();
1476 rewriter.replaceOp(reshapeOp, {descriptor});
1477 return success();
1478 }
1479
1480private:
1481 LogicalResult
1482 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1483 Type srcType, memref::ReshapeOp reshapeOp,
1484 memref::ReshapeOp::Adaptor adaptor,
1485 Value *descriptor) const {
1486 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1487 if (shapeMemRefType.hasStaticShape()) {
1488 MemRefType targetMemRefType =
1489 cast<MemRefType>(reshapeOp.getResult().getType());
1490 auto llvmTargetDescriptorTy =
1491 typeConverter->convertType<LLVM::LLVMStructType>(targetMemRefType);
1492 if (!llvmTargetDescriptorTy)
1493 return failure();
1494
1495 // Create descriptor.
1496 Location loc = reshapeOp.getLoc();
1497 auto desc =
1498 MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1499
1500 // Set allocated and aligned pointers.
1501 Value allocatedPtr, alignedPtr;
1502 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1503 reshapeOp.getSource(), adaptor.getSource(),
1504 &allocatedPtr, &alignedPtr);
1505 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1506 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1507
1508 // Extract the offset and strides from the type.
1509 int64_t offset;
1510 SmallVector<int64_t> strides;
1511 if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1512 return rewriter.notifyMatchFailure(
1513 reshapeOp, "failed to get stride and offset exprs");
1514
1515 if (!isStaticStrideOrOffset(offset))
1516 return rewriter.notifyMatchFailure(reshapeOp,
1517 "dynamic offset is unsupported");
1518
1519 desc.setConstantOffset(rewriter, loc, offset);
1520
1521 assert(targetMemRefType.getLayout().isIdentity() &&
1522 "Identity layout map is a precondition of a valid reshape op");
1523
1524 Type indexType = getIndexType();
1525 Value stride = nullptr;
1526 int64_t targetRank = targetMemRefType.getRank();
1527 for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1528 if (ShapedType::isStatic(strides[i])) {
1529 // If the stride for this dimension is dynamic, then use the product
1530 // of the sizes of the inner dimensions.
1531 stride =
1532 createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
1533 } else if (!stride) {
1534 // `stride` is null only in the first iteration of the loop. However,
1535 // since the target memref has an identity layout, we can safely set
1536 // the innermost stride to 1.
1537 stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
1538 }
1539
1540 Value dimSize;
1541 // If the size of this dimension is dynamic, then load it at runtime
1542 // from the shape operand.
1543 if (!targetMemRefType.isDynamicDim(i)) {
1544 dimSize = createIndexAttrConstant(rewriter, loc, indexType,
1545 targetMemRefType.getDimSize(i));
1546 } else {
1547 Value shapeOp = reshapeOp.getShape();
1548 Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
1549 dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index);
1550 Type indexType = getIndexType();
1551 if (dimSize.getType() != indexType)
1552 dimSize = typeConverter->materializeTargetConversion(
1553 rewriter, loc, indexType, dimSize);
1554 assert(dimSize && "Invalid memref element type");
1555 }
1556
1557 desc.setSize(rewriter, loc, i, dimSize);
1558 desc.setStride(rewriter, loc, i, stride);
1559
1560 // Prepare the stride value for the next dimension.
1561 stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
1562 }
1563
1564 *descriptor = desc;
1565 return success();
1566 }
1567
1568 // The shape is a rank-1 tensor with unknown length.
1569 Location loc = reshapeOp.getLoc();
1570 MemRefDescriptor shapeDesc(adaptor.getShape());
1571 Value resultRank = shapeDesc.size(rewriter, loc, 0);
1572
1573 // Extract address space and element type.
1574 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1575 unsigned addressSpace =
1576 *getTypeConverter()->getMemRefAddressSpace(targetType);
1577
1578 // Create the unranked memref descriptor that holds the ranked one. The
1579 // inner descriptor is allocated on stack.
1580 auto targetDesc = UnrankedMemRefDescriptor::poison(
1581 rewriter, loc, typeConverter->convertType(targetType));
1582 targetDesc.setRank(rewriter, loc, resultRank);
1584 rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
1585 Value underlyingDescPtr = LLVM::AllocaOp::create(
1586 rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
1587 allocationSize);
1588 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1589
1590 // Extract pointers and offset from the source memref.
1591 Value allocatedPtr, alignedPtr, offset;
1592 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1593 reshapeOp.getSource(), adaptor.getSource(),
1594 &allocatedPtr, &alignedPtr, &offset);
1595
1596 // Set pointers and offset.
1597 auto elementPtrType =
1598 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1599
1600 UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1601 elementPtrType, allocatedPtr);
1602 UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1603 underlyingDescPtr, elementPtrType,
1604 alignedPtr);
1605 UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1606 underlyingDescPtr, elementPtrType,
1607 offset);
1608
1609 // Use the offset pointer as base for further addressing. Copy over the new
1610 // shape and compute strides. For this, we create a loop from rank-1 to 0.
1612 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1614 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1615 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1616 Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
1617 Value resultRankMinusOne =
1618 LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);
1619
1620 Block *initBlock = rewriter.getInsertionBlock();
1621 Type indexType = getTypeConverter()->getIndexType();
1622 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1623
1624 Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1625 {indexType, indexType}, {loc, loc});
1626
1627 // Move the remaining initBlock ops to condBlock.
1628 Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1629 rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1630
1631 rewriter.setInsertionPointToEnd(initBlock);
1632 LLVM::BrOp::create(rewriter, loc,
1633 ValueRange({resultRankMinusOne, oneIndex}), condBlock);
1634 rewriter.setInsertionPointToStart(condBlock);
1635 Value indexArg = condBlock->getArgument(0);
1636 Value strideArg = condBlock->getArgument(1);
1637
1638 Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
1639 Value pred = LLVM::ICmpOp::create(
1640 rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
1641 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1642
1643 Block *bodyBlock =
1644 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1645 rewriter.setInsertionPointToStart(bodyBlock);
1646
1647 // Copy size from shape to descriptor.
1648 auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1649 Value sizeLoadGep = LLVM::GEPOp::create(
1650 rewriter, loc, llvmIndexPtrType,
1651 typeConverter->convertType(shapeMemRefType.getElementType()),
1652 shapeOperandPtr, indexArg);
1653 Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
1654 UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1655 targetSizesBase, indexArg, size);
1656
1657 // Write stride value and compute next one.
1658 UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1659 targetStridesBase, indexArg, strideArg);
1660 Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);
1661
1662 // Decrement loop counter and branch back.
1663 Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
1664 LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}),
1665 condBlock);
1666
1667 Block *remainder =
1668 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1669
1670 // Hook up the cond exit to the remainder.
1671 rewriter.setInsertionPointToEnd(condBlock);
1672 LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(),
1673 remainder, ValueRange());
1674
1675 // Reset position to beginning of new remainder block.
1676 rewriter.setInsertionPointToStart(remainder);
1677
1678 *descriptor = targetDesc;
1679 return success();
1680 }
1681};
1682
1683/// RessociatingReshapeOp must be expanded before we reach this stage.
1684/// Report that information.
1685template <typename ReshapeOp>
1686class ReassociatingReshapeOpConversion
1687 : public ConvertOpToLLVMPattern<ReshapeOp> {
1688public:
1690 using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1691
1692 LogicalResult
1693 matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1694 ConversionPatternRewriter &rewriter) const override {
1695 return rewriter.notifyMatchFailure(
1696 reshapeOp,
1697 "reassociation operations should have been expanded beforehand");
1698 }
1699};
1700
1701/// Subviews must be expanded before we reach this stage.
1702/// Report that information.
1703struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1704 using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1705
1706 LogicalResult
1707 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1708 ConversionPatternRewriter &rewriter) const override {
1709 return rewriter.notifyMatchFailure(
1710 subViewOp, "subview operations should have been expanded beforehand");
1711 }
1712};
1713
1714/// Conversion pattern that transforms a transpose op into:
1715/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
1716/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
1717/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
1718/// and stride. Size and stride are permutations of the original values.
1719/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
1720/// The transpose op is replaced by the alloca'ed pointer.
1721class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1722public:
1723 using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1724
1725 LogicalResult
1726 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1727 ConversionPatternRewriter &rewriter) const override {
1728 auto loc = transposeOp.getLoc();
1729 MemRefDescriptor viewMemRef(adaptor.getIn());
1730
1731 // No permutation, early exit.
1732 if (transposeOp.getPermutation().isIdentity())
1733 return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1734
1735 auto targetMemRef = MemRefDescriptor::poison(
1736 rewriter, loc,
1737 typeConverter->convertType(transposeOp.getIn().getType()));
1738
1739 // Copy the base and aligned pointers from the old descriptor to the new
1740 // one.
1741 targetMemRef.setAllocatedPtr(rewriter, loc,
1742 viewMemRef.allocatedPtr(rewriter, loc));
1743 targetMemRef.setAlignedPtr(rewriter, loc,
1744 viewMemRef.alignedPtr(rewriter, loc));
1745
1746 // Copy the offset pointer from the old descriptor to the new one.
1747 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1748
1749 // Iterate over the dimensions and apply size/stride permutation:
1750 // When enumerating the results of the permutation map, the enumeration
1751 // index is the index into the target dimensions and the DimExpr points to
1752 // the dimension of the source memref.
1753 for (const auto &en :
1754 llvm::enumerate(transposeOp.getPermutation().getResults())) {
1755 int targetPos = en.index();
1756 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1757 targetMemRef.setSize(rewriter, loc, targetPos,
1758 viewMemRef.size(rewriter, loc, sourcePos));
1759 targetMemRef.setStride(rewriter, loc, targetPos,
1760 viewMemRef.stride(rewriter, loc, sourcePos));
1761 }
1762
1763 rewriter.replaceOp(transposeOp, {targetMemRef});
1764 return success();
1765 }
1766};
1767
1768/// Conversion pattern that transforms an op into:
1769/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1770/// 2. Updates to the descriptor to introduce the data ptr, offset, size
1771/// and stride.
1772/// The view op is replaced by the descriptor.
1773struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1775
1776 // Build and return the value for the idx^th shape dimension, either by
1777 // returning the constant shape dimension or counting the proper dynamic size.
1778 Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1779 ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1780 Type indexType) const {
1781 assert(idx < shape.size());
1782 if (ShapedType::isStatic(shape[idx]))
1783 return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
1784 // Count the number of dynamic dims in range [0, idx]
1785 unsigned nDynamic =
1786 llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1787 return dynamicSizes[nDynamic];
1788 }
1789
1790 // Build and return the idx^th stride, either by returning the constant stride
1791 // or by computing the dynamic stride from the current `runningStride` and
1792 // `nextSize`. The caller should keep a running stride and update it with the
1793 // result returned by this function.
1794 Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1795 ArrayRef<int64_t> strides, Value nextSize,
1796 Value runningStride, unsigned idx, Type indexType) const {
1797 assert(idx < strides.size());
1798 if (ShapedType::isStatic(strides[idx]))
1799 return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
1800 if (nextSize)
1801 return runningStride
1802 ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
1803 : nextSize;
1804 assert(!runningStride);
1805 return createIndexAttrConstant(rewriter, loc, indexType, 1);
1806 }
1807
1808 LogicalResult
1809 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1810 ConversionPatternRewriter &rewriter) const override {
1811 auto loc = viewOp.getLoc();
1812
1813 auto viewMemRefType = viewOp.getType();
1814 auto targetElementTy =
1815 typeConverter->convertType(viewMemRefType.getElementType());
1816 auto targetDescTy = typeConverter->convertType(viewMemRefType);
1817 if (!targetDescTy || !targetElementTy ||
1818 !LLVM::isCompatibleType(targetElementTy) ||
1819 !LLVM::isCompatibleType(targetDescTy))
1820 return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1821 failure();
1822
1823 int64_t offset;
1825 auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1826 if (failed(successStrides))
1827 return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1828 assert(offset == 0 && "expected offset to be 0");
1829
1830 // Target memref must be contiguous in memory (innermost stride is 1), or
1831 // empty (special case when at least one of the memref dimensions is 0).
1832 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1833 return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1834 failure();
1835
1836 // Create the descriptor.
1837 MemRefDescriptor sourceMemRef(adaptor.getSource());
1838 auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);
1839
1840 // Field 1: Copy the allocated pointer, used for malloc/free.
1841 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1842 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1843 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1844
1845 // Field 2: Copy the actual aligned pointer to payload.
1846 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1847 alignedPtr = LLVM::GEPOp::create(
1848 rewriter, loc, alignedPtr.getType(),
1849 typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1850 adaptor.getByteShift());
1851
1852 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1853
1854 Type indexType = getIndexType();
1855 // Field 3: The offset in the resulting type must be 0. This is
1856 // because of the type change: an offset on srcType* may not be
1857 // expressible as an offset on dstType*.
1858 targetMemRef.setOffset(
1859 rewriter, loc,
1860 createIndexAttrConstant(rewriter, loc, indexType, offset));
1861
1862 // Early exit for 0-D corner case.
1863 if (viewMemRefType.getRank() == 0)
1864 return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1865
1866 // Fields 4 and 5: Update sizes and strides.
1867 Value stride = nullptr, nextSize = nullptr;
1868 for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1869 // Update size.
1870 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1871 adaptor.getSizes(), i, indexType);
1872 targetMemRef.setSize(rewriter, loc, i, size);
1873 // Update stride.
1874 stride =
1875 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1876 targetMemRef.setStride(rewriter, loc, i, stride);
1877 nextSize = size;
1878 }
1879
1880 rewriter.replaceOp(viewOp, {targetMemRef});
1881 return success();
1882 }
1883};
1884
1885//===----------------------------------------------------------------------===//
1886// AtomicRMWOpLowering
1887//===----------------------------------------------------------------------===//
1888
/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
/// Returns std::nullopt for kinds without a direct llvm.atomicrmw equivalent
/// (e.g. mulf/muli), which are handled by GenericAtomicRMWOpLowering instead.
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    // TODO: remove this by end of 2025.
    LDBG() << "the lowering of memref.atomicrmw maximumf changed "
              "from fmax to fmaximum, expect more NaNs";
    return LLVM::AtomicBinOp::fmaximum;
  case arith::AtomicRMWKind::maxnumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    // TODO: remove this by end of 2025.
    LDBG() << "the lowering of memref.atomicrmw minimum changed "
              "from fmin to fminimum, expect more NaNs";
    return LLVM::AtomicBinOp::fminimum;
  case arith::AtomicRMWKind::minnumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::xori:
    return LLVM::AtomicBinOp::_xor;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}
1933
/// Lower memref.atomic_rmw to a single llvm.atomicrmw when the RMW kind has a
/// direct LLVM equivalent (see matchSimpleAtomicOp); otherwise fail to match
/// so the generic cmpxchg-loop lowering can apply.
struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto memRefType = atomicOp.getMemRefType();
    // Only strided memrefs can be addressed via getStridedElementPtr.
    SmallVector<int64_t> strides;
    int64_t offset;
    if (failed(memRefType.getStridesAndOffset(strides, offset)))
      return failure();
    auto dataPtr =
        getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
                             adaptor.getMemref(), adaptor.getIndices());
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};
1957
1958/// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
1959class ConvertExtractAlignedPointerAsIndex
1960 : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1961public:
1963 memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1964
1965 LogicalResult
1966 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1967 OpAdaptor adaptor,
1968 ConversionPatternRewriter &rewriter) const override {
1969 BaseMemRefType sourceTy = extractOp.getSource().getType();
1970
1971 Value alignedPtr;
1972 if (sourceTy.hasRank()) {
1973 MemRefDescriptor desc(adaptor.getSource());
1974 alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1975 } else {
1976 auto elementPtrTy = LLVM::LLVMPointerType::get(
1977 rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1978
1979 UnrankedMemRefDescriptor desc(adaptor.getSource());
1980 Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1981
1983 rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1984 elementPtrTy);
1985 }
1986
1987 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1988 extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1989 return success();
1990 }
1991};
1992
/// Materialize the MemRef descriptor represented by the results of
/// ExtractStridedMetadataOp.
class ExtractStridedMetadataOpLowering
    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();

    auto sourceMemRefType = cast<MemRefType>(source.getType());
    int64_t rank = sourceMemRefType.getRank();
    // Results are: base buffer, offset, `rank` sizes, `rank` strides.
    SmallVector<Value> results;
    results.reserve(2 + rank * 2);

    // Base buffer: a 0-d descriptor wrapping the allocated/aligned pointers.
    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
        baseBuffer, alignedBuffer);
    results.push_back((Value)dstMemRef);

    // Offset.
    results.push_back(sourceMemRef.offset(rewriter, loc));

    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.size(rewriter, loc, i));
    // Strides.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.stride(rewriter, loc, i));

    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};
2042
2043} // namespace
2044
2046 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2047 SymbolTableCollection *symbolTables) {
2048 // clang-format off
2049 patterns.add<
2050 AllocaOpLowering,
2051 AllocaScopeOpLowering,
2052 AssumeAlignmentOpLowering,
2053 AtomicRMWOpLowering,
2054 ConvertExtractAlignedPointerAsIndex,
2055 DimOpLowering,
2056 DistinctObjectsOpLowering,
2057 ExtractStridedMetadataOpLowering,
2058 GenericAtomicRMWOpLowering,
2059 GetGlobalMemrefOpLowering,
2060 LoadOpLowering,
2061 MemRefCastOpLowering,
2062 MemRefReinterpretCastOpLowering,
2063 MemRefReshapeOpLowering,
2064 MemorySpaceCastOpLowering,
2065 PrefetchOpLowering,
2066 RankOpLowering,
2067 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2068 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2069 StoreOpLowering,
2070 SubViewOpLowering,
2072 ViewOpLowering>(converter);
2073 // clang-format on
2074 patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
2075 symbolTables);
2076 auto allocLowering = converter.getOptions().allocLowering;
2078 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
2079 symbolTables);
2080 else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
2081 patterns.add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
2082}
2083
2084namespace {
2085struct FinalizeMemRefToLLVMConversionPass
2086 : public impl::FinalizeMemRefToLLVMConversionPassBase<
2087 FinalizeMemRefToLLVMConversionPass> {
2088 using FinalizeMemRefToLLVMConversionPassBase::
2089 FinalizeMemRefToLLVMConversionPassBase;
2090
2091 void runOnOperation() override {
2092 Operation *op = getOperation();
2093 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2095 dataLayoutAnalysis.getAtOrAbove(op));
2096 options.allocLowering =
2099
2100 options.useGenericFunctions = useGenericFunctions;
2101
2102 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2103 options.overrideIndexBitwidth(indexBitwidth);
2104
2105 LLVMTypeConverter typeConverter(&getContext(), options,
2106 &dataLayoutAnalysis);
2107 RewritePatternSet patterns(&getContext());
2108 SymbolTableCollection symbolTables;
2109 populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns,
2110 &symbolTables);
2112 target.addLegalOp<func::FuncOp>();
2113 if (failed(applyPartialConversion(op, target, std::move(patterns))))
2114 signalPassFailure();
2115 }
2116};
2117
2118/// Implement the interface to convert MemRef to LLVM.
2119struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
2120 MemRefToLLVMDialectInterface(Dialect *dialect)
2121 : ConvertToLLVMPatternInterface(dialect) {}
2122
2123 void loadDependentDialects(MLIRContext *context) const final {
2124 context->loadDialect<LLVM::LLVMDialect>();
2125 }
2126
2127 /// Hook for derived dialect interface to provide conversion patterns
2128 /// and mark dialect legal for the conversion target.
2129 void populateConvertToLLVMConversionPatterns(
2130 ConversionTarget &target, LLVMTypeConverter &typeConverter,
2131 RewritePatternSet &patterns) const final {
2132 populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
2133 }
2134};
2135
2136} // namespace
2137
2139 registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
2140 dialect->addInterfaces<MemRefToLLVMDialectInterface>();
2141 });
2142}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:58
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
const LowerToLLVMOptions & getOptions() const
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const DataLayoutAnalysis * getDataLayoutAnalysis() const
Returns the data layout analysis to query during conversion.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Options to control the LLVM lowering.
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:274
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents a collection of SymbolTables.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
static Value computeSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, UnrankedMemRefDescriptor desc, unsigned addressSpace)
Builds and returns IR computing the size in bytes (suitable for opaque allocation).
void setRank(OpBuilder &builder, Location loc, Value value)
Builds IR setting the rank in the descriptor.
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
static Value offsetBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR for getting the pointer to the offset's location.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value)
Builds IR setting ranked memref descriptor ptr.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry &registry)
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM di...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.