MLIR 23.0.0git
MemRefToLLVM.cpp
Go to the documentation of this file.
1//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
23#include "mlir/IR/AffineMap.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/Pass/Pass.h"
27#include "llvm/Support/DebugLog.h"
28#include "llvm/Support/MathExtras.h"
29
30#include <optional>
31
32#define DEBUG_TYPE "memref-to-llvm"
33
34namespace mlir {
35#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
36#include "mlir/Conversion/Passes.h.inc"
37} // namespace mlir
38
39using namespace mlir;
40
// GEP flags attached to all address computations emitted by this lowering:
// every pointer arithmetic we generate stays within the allocated object
// (inbounds) and the unsigned offset math does not wrap (nuw).
static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
    LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
43
44namespace {
45
46static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
47 return ShapedType::isStatic(strideOrOffset);
48}
49
50static FailureOr<LLVM::LLVMFuncOp>
51getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
52 Operation *module, SymbolTableCollection *symbolTables) {
53 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
54
55 if (useGenericFn)
56 return LLVM::lookupOrCreateGenericFreeFn(b, module, symbolTables);
57
58 return LLVM::lookupOrCreateFreeFn(b, module, symbolTables);
59}
60
61static FailureOr<LLVM::LLVMFuncOp>
62getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
63 Operation *module, Type indexType,
64 SymbolTableCollection *symbolTables) {
65 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
66 if (useGenericFn)
67 return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType,
68 symbolTables);
69
70 return LLVM::lookupOrCreateMallocFn(b, module, indexType, symbolTables);
71}
72
73static FailureOr<LLVM::LLVMFuncOp>
74getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
75 Operation *module, Type indexType,
76 SymbolTableCollection *symbolTables) {
77 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
78
79 if (useGenericFn)
80 return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType,
81 symbolTables);
82
83 return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType, symbolTables);
84}
85
86/// Computes the aligned value for 'input' as follows:
/// bumped = input + alignment - 1
88/// aligned = bumped - bumped % alignment
89static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
90 Value input, Value alignment) {
91 Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(),
92 rewriter.getIndexAttr(1));
93 Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
94 Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
95 Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
96 return LLVM::SubOp::create(rewriter, loc, bumped, mod);
97}
98
99/// Computes the byte size for the MemRef element type.
100static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
101 MemRefType memRefType, Operation *op,
102 const DataLayout *defaultLayout) {
103 const DataLayout *layout = defaultLayout;
104 if (const DataLayoutAnalysis *analysis =
105 typeConverter->getDataLayoutAnalysis()) {
106 layout = &analysis->getAbove(op);
107 }
108 Type elementType = memRefType.getElementType();
109 if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
110 return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
111 if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
112 return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
113 *layout);
114 return layout->getTypeSize(elementType);
115}
116
117static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
118 Location loc, Value allocatedPtr,
119 MemRefType memRefType, Type elementPtrType,
120 const LLVMTypeConverter &typeConverter) {
121 auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
122 FailureOr<unsigned> maybeMemrefAddrSpace =
123 typeConverter.getMemRefAddressSpace(memRefType);
124 assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
125 unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
126 if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
127 allocatedPtr = LLVM::AddrSpaceCastOp::create(
128 rewriter, loc,
129 LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
130 allocatedPtr);
131 return allocatedPtr;
132}
133
134class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
135 SymbolTableCollection *symbolTables = nullptr;
136
137public:
138 explicit AllocOpLowering(const LLVMTypeConverter &typeConverter,
139 SymbolTableCollection *symbolTables = nullptr,
140 PatternBenefit benefit = 1)
141 : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
142 symbolTables(symbolTables) {}
143
144 LogicalResult
145 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
146 ConversionPatternRewriter &rewriter) const override {
147 auto loc = op.getLoc();
148 MemRefType memRefType = op.getType();
149 if (!isConvertibleAndHasIdentityMaps(memRefType))
150 return rewriter.notifyMatchFailure(op, "incompatible memref type");
151
152 // Get or insert alloc function into the module.
153 FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
154 getNotalignedAllocFn(rewriter, getTypeConverter(),
155 op->getParentWithTrait<OpTrait::SymbolTable>(),
156 getIndexType(), symbolTables);
157 if (failed(allocFuncOp))
158 return failure();
159
160 // Get actual sizes of the memref as values: static sizes are constant
161 // values and dynamic sizes are passed to 'alloc' as operands. In case of
162 // zero-dimensional memref, assume a scalar (size 1).
164 SmallVector<Value, 4> strides;
165 Value sizeBytes;
166
167 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
168 rewriter, sizes, strides, sizeBytes, true);
169
170 Value alignment = getAlignment(rewriter, loc, op);
171 if (alignment) {
172 // Adjust the allocation size to consider alignment.
173 sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
174 }
175
176 // Allocate the underlying buffer.
177 Type elementPtrType = this->getElementPtrType(memRefType);
178 assert(elementPtrType && "could not compute element ptr type");
179 auto results =
180 LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);
181
182 Value allocatedPtr =
183 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
184 elementPtrType, *getTypeConverter());
185 Value alignedPtr = allocatedPtr;
186 if (alignment) {
187 // Compute the aligned pointer.
188 Value allocatedInt =
189 LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
190 Value alignmentInt =
191 createAligned(rewriter, loc, allocatedInt, alignment);
192 alignedPtr =
193 LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
194 }
195
196 // Create the MemRef descriptor.
197 auto memRefDescriptor = this->createMemRefDescriptor(
198 loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
199
200 // Return the final value of the descriptor.
201 rewriter.replaceOp(op, {memRefDescriptor});
202 return success();
203 }
204
205 /// Computes the alignment for the given memory allocation op.
206 template <typename OpType>
207 Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
208 OpType op) const {
209 MemRefType memRefType = op.getType();
210 Value alignment;
211 if (auto alignmentAttr = op.getAlignment()) {
212 Type indexType = getIndexType();
213 alignment =
214 createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
215 } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
216 // In the case where no alignment is specified, we may want to override
217 // `malloc's` behavior. `malloc` typically aligns at the size of the
218 // biggest scalar on a target HW. For non-scalars, use the natural
219 // alignment of the LLVM type given by the LLVM DataLayout.
220 alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
221 }
222 return alignment;
223 }
224};
225
226class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
227 SymbolTableCollection *symbolTables = nullptr;
228
229public:
230 explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter,
231 SymbolTableCollection *symbolTables = nullptr,
232 PatternBenefit benefit = 1)
233 : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
234 symbolTables(symbolTables) {}
235
236 LogicalResult
237 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
238 ConversionPatternRewriter &rewriter) const override {
239 auto loc = op.getLoc();
240 MemRefType memRefType = op.getType();
241 if (!isConvertibleAndHasIdentityMaps(memRefType))
242 return rewriter.notifyMatchFailure(op, "incompatible memref type");
243
244 // Get or insert alloc function into module.
245 FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
246 getAlignedAllocFn(rewriter, getTypeConverter(),
247 op->getParentWithTrait<OpTrait::SymbolTable>(),
248 getIndexType(), symbolTables);
249 if (failed(allocFuncOp))
250 return failure();
251
252 // Get actual sizes of the memref as values: static sizes are constant
253 // values and dynamic sizes are passed to 'alloc' as operands. In case of
254 // zero-dimensional memref, assume a scalar (size 1).
256 SmallVector<Value, 4> strides;
257 Value sizeBytes;
258
259 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
260 rewriter, sizes, strides, sizeBytes, !false);
261
262 int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
263
264 Value allocAlignment =
265 createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
266
267 // Function aligned_alloc requires size to be a multiple of alignment; we
268 // pad the size to the next multiple if necessary.
269 if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
270 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
271
272 Type elementPtrType = this->getElementPtrType(memRefType);
273 auto results =
274 LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
275 ValueRange({allocAlignment, sizeBytes}));
276
277 Value ptr =
278 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
279 elementPtrType, *getTypeConverter());
280
281 // Create the MemRef descriptor.
282 auto memRefDescriptor = this->createMemRefDescriptor(
283 loc, memRefType, ptr, ptr, sizes, strides, rewriter);
284
285 // Return the final value of the descriptor.
286 rewriter.replaceOp(op, {memRefDescriptor});
287 return success();
288 }
289
290 /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
291 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
292
293 /// Computes the alignment for aligned_alloc used to allocate the buffer for
294 /// the memory allocation op.
295 ///
296 /// Aligned_alloc requires the allocation size to be a power of two, and the
297 /// allocation size to be a multiple of the alignment.
298 int64_t alignedAllocationGetAlignment(memref::AllocOp op,
299 const DataLayout *defaultLayout) const {
300 if (std::optional<uint64_t> alignment = op.getAlignment())
301 return *alignment;
302
303 // Whenever we don't have alignment set, we will use an alignment
304 // consistent with the element type; since the allocation size has to be a
305 // power of two, we will bump to the next power of two if it isn't.
306 unsigned eltSizeBytes = getMemRefEltSizeInBytes(
307 getTypeConverter(), op.getType(), op, defaultLayout);
308 return std::max(kMinAlignedAllocAlignment,
309 llvm::PowerOf2Ceil(eltSizeBytes));
310 }
311
312 /// Returns true if the memref size in bytes is known to be a multiple of
313 /// factor.
314 bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
315 const DataLayout *defaultLayout) const {
316 uint64_t sizeDivisor =
317 getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
318 for (unsigned i = 0, e = type.getRank(); i < e; i++) {
319 if (type.isDynamicDim(i))
320 continue;
321 sizeDivisor = sizeDivisor * type.getDimSize(i);
322 }
323 return sizeDivisor % factor == 0;
324 }
325
326private:
327 /// Default layout to use in absence of the corresponding analysis.
328 DataLayout defaultLayout;
329};
330
331struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
333
334 /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
335 /// is set to null for stack allocations. `accessAlignment` is set if
336 /// alignment is needed post allocation (for eg. in conjunction with malloc).
337 LogicalResult
338 matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter) const override {
340 auto loc = op.getLoc();
341 MemRefType memRefType = op.getType();
342 if (!isConvertibleAndHasIdentityMaps(memRefType))
343 return rewriter.notifyMatchFailure(op, "incompatible memref type");
344
345 // Get actual sizes of the memref as values: static sizes are constant
346 // values and dynamic sizes are passed to 'alloc' as operands. In case of
347 // zero-dimensional memref, assume a scalar (size 1).
349 SmallVector<Value, 4> strides;
350 Value size;
351
352 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
353 rewriter, sizes, strides, size, !true);
354
355 // With alloca, one gets a pointer to the element type right away.
356 // For stack allocations.
357 auto elementType =
358 typeConverter->convertType(op.getType().getElementType());
359 FailureOr<unsigned> maybeAddressSpace =
360 getTypeConverter()->getMemRefAddressSpace(op.getType());
361 assert(succeeded(maybeAddressSpace) && "unsupported address space");
362 unsigned addrSpace = *maybeAddressSpace;
363 auto elementPtrType =
364 LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
365
366 auto allocatedElementPtr =
367 LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
368 op.getAlignment().value_or(0));
369
370 // Create the MemRef descriptor.
371 auto memRefDescriptor = this->createMemRefDescriptor(
372 loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
373 strides, rewriter);
374
375 // Return the final value of the descriptor.
376 rewriter.replaceOp(op, {memRefDescriptor});
377 return success();
378 }
379};
380
/// Lowers `memref.alloca_scope` by inlining its body region into the
/// enclosing block between a stacksave/stackrestore pair, so that all
/// allocas made inside the scope are released when the scope exits.
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      // No results: the split-off block is the continuation directly.
      continueBlock = remainingOpsBlock;
    } else {
      // With results: create a dedicated continuation block whose block
      // arguments carry the scope results, then fall through to the
      // remaining ops.
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
    LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out the body of the region.
    rewriter.setInsertionPoint(branchOp);
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    // Replace the op with values return from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};
435
436struct AssumeAlignmentOpLowering
437 : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
439 memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
440 explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
442
443 LogicalResult
444 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
445 ConversionPatternRewriter &rewriter) const override {
446 Value memref = adaptor.getMemref();
447 unsigned alignment = op.getAlignment();
448 auto loc = op.getLoc();
449
450 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
451 Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
452 /*indices=*/{});
453
454 // Emit llvm.assume(true) ["align"(memref, alignment)].
455 // This is more direct than ptrtoint-based checks, is explicitly supported,
456 // and works with non-integral address spaces.
457 Value trueCond =
458 LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
459 Value alignmentConst =
460 createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
461 LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr,
462 alignmentConst);
463 rewriter.replaceOp(op, memref);
464 return success();
465 }
466};
467
468struct DistinctObjectsOpLowering
469 : public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
471 memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
472 explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter)
474
475 LogicalResult
476 matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
477 ConversionPatternRewriter &rewriter) const override {
478 ValueRange operands = adaptor.getOperands();
479 if (operands.size() <= 1) {
480 // Fast path.
481 rewriter.replaceOp(op, operands);
482 return success();
483 }
484
485 Location loc = op.getLoc();
487 for (auto [origOperand, newOperand] :
488 llvm::zip_equal(op.getOperands(), operands)) {
489 auto memrefType = cast<MemRefType>(origOperand.getType());
490 MemRefDescriptor memRefDescriptor(newOperand);
491 Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
492 memrefType);
493 ptrs.push_back(ptr);
494 }
495
496 auto cond =
497 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
498 // Generate separate_storage assumptions for each pair of pointers.
499 for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
500 for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
501 Value ptr1 = ptrs[i];
502 Value ptr2 = ptrs[j];
503 LLVM::AssumeOp::create(rewriter, loc, cond,
504 LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2);
505 }
506 }
507
508 rewriter.replaceOp(op, operands);
509 return success();
510 }
511};
512
513// A `dealloc` is converted into a call to `free` on the underlying data buffer.
514// The memref descriptor being an SSA value, there is no need to clean it up
515// in any way.
516class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
517 SymbolTableCollection *symbolTables = nullptr;
518
519public:
520 explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter,
521 SymbolTableCollection *symbolTables = nullptr,
522 PatternBenefit benefit = 1)
523 : ConvertOpToLLVMPattern<memref::DeallocOp>(typeConverter, benefit),
524 symbolTables(symbolTables) {}
525
526 LogicalResult
527 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
528 ConversionPatternRewriter &rewriter) const override {
529 // Insert the `free` declaration if it is not already present.
530 FailureOr<LLVM::LLVMFuncOp> freeFunc =
531 getFreeFn(rewriter, getTypeConverter(),
532 op->getParentWithTrait<OpTrait::SymbolTable>(), symbolTables);
533 if (failed(freeFunc))
534 return failure();
535 Value allocatedPtr;
536 if (auto unrankedTy =
537 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
538 auto elementPtrTy = LLVM::LLVMPointerType::get(
539 rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
541 rewriter, op.getLoc(),
542 UnrankedMemRefDescriptor(adaptor.getMemref())
543 .memRefDescPtr(rewriter, op.getLoc()),
544 elementPtrTy);
545 } else {
546 allocatedPtr = MemRefDescriptor(adaptor.getMemref())
547 .allocatedPtr(rewriter, op.getLoc());
548 }
549 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
550 allocatedPtr);
551 return success();
552 }
553};
554
555// A `dim` is converted to a constant for static sizes and to an access to the
556// size stored in the memref descriptor for dynamic sizes.
557struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
559
560 LogicalResult
561 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
562 ConversionPatternRewriter &rewriter) const override {
563 Type operandType = dimOp.getSource().getType();
564 if (isa<UnrankedMemRefType>(operandType)) {
565 FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
566 operandType, dimOp, adaptor.getOperands(), rewriter);
567 if (failed(extractedSize))
568 return failure();
569 rewriter.replaceOp(dimOp, {*extractedSize});
570 return success();
571 }
572 if (isa<MemRefType>(operandType)) {
573 rewriter.replaceOp(
574 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
575 adaptor.getOperands(), rewriter)});
576 return success();
577 }
578 llvm_unreachable("expected MemRefType or UnrankedMemRefType");
579 }
580
581private:
582 FailureOr<Value>
583 extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
584 OpAdaptor adaptor,
585 ConversionPatternRewriter &rewriter) const {
586 Location loc = dimOp.getLoc();
587
588 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
589 auto scalarMemRefType =
590 MemRefType::get({}, unrankedMemRefType.getElementType());
591 FailureOr<unsigned> maybeAddressSpace =
592 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
593 if (failed(maybeAddressSpace)) {
594 dimOp.emitOpError("memref memory space must be convertible to an integer "
595 "address space");
596 return failure();
597 }
598 unsigned addressSpace = *maybeAddressSpace;
599
600 // Extract pointer to the underlying ranked descriptor and bitcast it to a
601 // memref<element_type> descriptor pointer to minimize the number of GEP
602 // operations.
603 UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
604 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
605
606 Type elementType = typeConverter->convertType(scalarMemRefType);
607
608 // Get pointer to offset field of memref<element_type> descriptor.
609 auto indexPtrTy =
610 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
611 Value offsetPtr =
612 LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
613 underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2});
614
615 // The size value that we have to extract can be obtained using GEPop with
616 // `dimOp.index() + 1` index argument.
617 Value idxPlusOne = LLVM::AddOp::create(
618 rewriter, loc,
619 createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
620 adaptor.getIndex());
621 Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
622 getTypeConverter()->getIndexType(),
623 offsetPtr, idxPlusOne);
624 return LLVM::LoadOp::create(rewriter, loc,
625 getTypeConverter()->getIndexType(), sizePtr)
626 .getResult();
627 }
628
629 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
630 if (auto idx = dimOp.getConstantIndex())
631 return idx;
632
633 if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
634 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
635
636 return std::nullopt;
637 }
638
639 Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
640 OpAdaptor adaptor,
641 ConversionPatternRewriter &rewriter) const {
642 Location loc = dimOp.getLoc();
643
644 // Take advantage if index is constant.
645 MemRefType memRefType = cast<MemRefType>(operandType);
646 Type indexType = getIndexType();
647 if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
648 int64_t i = *index;
649 if (i >= 0 && i < memRefType.getRank()) {
650 if (memRefType.isDynamicDim(i)) {
651 // extract dynamic size from the memref descriptor.
652 MemRefDescriptor descriptor(adaptor.getSource());
653 return descriptor.size(rewriter, loc, i);
654 }
655 // Use constant for static size.
656 int64_t dimSize = memRefType.getDimSize(i);
657 return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
658 }
659 }
660 Value index = adaptor.getIndex();
661 int64_t rank = memRefType.getRank();
662 MemRefDescriptor memrefDescriptor(adaptor.getSource());
663 return memrefDescriptor.size(rewriter, loc, index, rank);
664 }
665};
666
667/// Common base for load and store operations on MemRefs. Restricts the match
668/// to supported MemRef types. Provides functionality to emit code accessing a
669/// specific element of the underlying data buffer.
670template <typename Derived>
671struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
673 using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
674 using Base = LoadStoreOpLowering<Derived>;
675};
676
677/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
678/// retried until it succeeds in atomically storing a new value into memory.
679///
680/// +---------------------------------+
681/// | <code before the AtomicRMWOp> |
682/// | <compute initial %loaded> |
683/// | cf.br loop(%loaded) |
684/// +---------------------------------+
685/// |
686/// -------| |
687/// | v v
688/// | +--------------------------------+
689/// | | loop(%loaded): |
690/// | | <body contents> |
691/// | | %pair = cmpxchg |
692/// | | %ok = %pair[0] |
693/// | | %new = %pair[1] |
694/// | | cf.cond_br %ok, end, loop(%new) |
695/// | +--------------------------------+
696/// | | |
697/// |----------- |
698/// v
699/// +--------------------------------+
700/// | end: |
701/// | <code after the AtomicRMWOp> |
702/// +--------------------------------+
703///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // `llvm.cmpxchg` only supports integer or pointer operands. For
    // floating-point element types, perform the CAS on a same-width integer
    // and bitcast at the boundaries.
    bool needsBitcast = isa<FloatType>(valueType);
    Type cmpxchgType = valueType;
    if (needsBitcast) {
      unsigned bitWidth = cast<FloatType>(valueType).getWidth();
      cmpxchgType = rewriter.getIntegerType(bitWidth);
    }

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    // The loop block carries the currently-loaded value as a block argument
    // (in its integer representation when bitcasting is needed).
    loopBlock->addArgument(cmpxchgType, loc);

    // NOTE(review): `Block::iterator(atomicOp)++` is a post-increment on a
    // temporary, so the value passed to splitBlock is the iterator at
    // `atomicOp` itself — `atomicOp` becomes the head of `endBlock` and is
    // replaced at the end of this rewrite. Confirm this split point is
    // intended rather than `std::next(...)`.
    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(
        rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
    Value init = LLVM::LoadOp::create(
        rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
        dataPtr);
    if (needsBitcast)
      init = LLVM::BitcastOp::create(rewriter, loc, cmpxchgType, init);
    LLVM::BrOp::create(rewriter, loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    Value loopArgument = loopBlock->getArgument(0);
    Value loopArgForBody = loopArgument;
    if (needsBitcast)
      loopArgForBody =
          LLVM::BitcastOp::create(rewriter, loc, valueType, loopArgument);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgForBody);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }

    // The terminator's operand, remapped through the clone, is the value the
    // body wants to store.
    Value result =
        mapping.lookupOrNull(entryBlock.getTerminator()->getOperand(0));
    if (!result) {
      return atomicOp.emitError("result not defined in region");
    }
    if (needsBitcast)
      result = LLVM::BitcastOp::create(rewriter, loc, cmpxchgType, result);

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg =
        LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
                                      result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
    Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef<Value>(),
                           loopBlock, newLoaded);

    // The 'result' of the atomic_rmw op is the newly loaded value. Bitcast
    // back to the float type if needed. Insert at the start of `endBlock` so
    // the bitcast precedes the existing terminator (split into endBlock).
    if (needsBitcast) {
      rewriter.setInsertionPointToStart(endBlock);
      newLoaded = LLVM::BitcastOp::create(rewriter, loc, valueType, newLoaded);
    }
    rewriter.setInsertionPointToEnd(endBlock);
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};
797
798/// Returns the LLVM type of the global variable given the memref type `type`.
799static Type
800convertGlobalMemrefTypeToLLVM(MemRefType type,
801 const LLVMTypeConverter &typeConverter) {
802 // LLVM type for a global memref will be a multi-dimension array. For
803 // declarations or uninitialized global memrefs, we can potentially flatten
804 // this to a 1D array. However, for memref.global's with an initial value,
805 // we do not intend to flatten the ElementsAttribute when going from std ->
806 // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
807 Type elementType = typeConverter.convertType(type.getElementType());
808 Type arrayTy = elementType;
809 // Shape has the outermost dim at index 0, so need to walk it backwards
810 for (int64_t dim : llvm::reverse(type.getShape()))
811 arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
812 return arrayTy;
813}
814
815/// GlobalMemrefOp is lowered to a LLVM Global Variable.
816class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
817 SymbolTableCollection *symbolTables = nullptr;
818
819public:
820 explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter,
821 SymbolTableCollection *symbolTables = nullptr,
822 PatternBenefit benefit = 1)
823 : ConvertOpToLLVMPattern<memref::GlobalOp>(typeConverter, benefit),
824 symbolTables(symbolTables) {}
825
826 LogicalResult
827 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
828 ConversionPatternRewriter &rewriter) const override {
829 MemRefType type = global.getType();
830 if (!isConvertibleAndHasIdentityMaps(type))
831 return failure();
832
833 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
834
835 LLVM::Linkage linkage =
836 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
837 bool isExternal = global.isExternal();
838 bool isUninitialized = global.isUninitialized();
839
840 Attribute initialValue = nullptr;
841 if (!isExternal && !isUninitialized) {
842 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
843 initialValue = elementsAttr;
844
845 // For scalar memrefs, the global variable created is of the element type,
846 // so unpack the elements attribute to extract the value.
847 if (type.getRank() == 0)
848 initialValue = elementsAttr.getSplatValue<Attribute>();
849 }
850
851 uint64_t alignment = global.getAlignment().value_or(0);
852 FailureOr<unsigned> addressSpace =
853 getTypeConverter()->getMemRefAddressSpace(type);
854 if (failed(addressSpace))
855 return global.emitOpError(
856 "memory space cannot be converted to an integer address space");
857
858 // Remove old operation from symbol table.
859 SymbolTable *symbolTable = nullptr;
860 if (symbolTables) {
861 Operation *symbolTableOp =
863 symbolTable = &symbolTables->getSymbolTable(symbolTableOp);
864 symbolTable->remove(global);
865 }
866
867 // Create new operation.
868 auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
869 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
870 initialValue, alignment, *addressSpace);
871
872 // Insert new operation into symbol table.
873 if (symbolTable)
874 symbolTable->insert(newGlobal, rewriter.getInsertionPoint());
875
876 if (!isExternal && isUninitialized) {
877 rewriter.createBlock(&newGlobal.getInitializerRegion());
878 Value undef[] = {
879 LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
880 LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
881 }
882 return success();
883 }
884};
885
886/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
887/// the first element stashed into the descriptor. This reuses
888/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
889struct GetGlobalMemrefOpLowering
890 : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
891 using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;
892
893 /// Buffer "allocation" for memref.get_global op is getting the address of
894 /// the global variable referenced.
895 LogicalResult
896 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
897 ConversionPatternRewriter &rewriter) const override {
898 auto loc = op.getLoc();
899 MemRefType memRefType = op.getType();
900 if (!isConvertibleAndHasIdentityMaps(memRefType))
901 return rewriter.notifyMatchFailure(op, "incompatible memref type");
902
903 // Get actual sizes of the memref as values: static sizes are constant
904 // values and dynamic sizes are passed to 'alloc' as operands. In case of
905 // zero-dimensional memref, assume a scalar (size 1).
907 SmallVector<Value, 4> strides;
908 Value sizeBytes;
909
910 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
911 rewriter, sizes, strides, sizeBytes, !false);
912
913 MemRefType type = cast<MemRefType>(op.getResult().getType());
914
915 // This is called after a type conversion, which would have failed if this
916 // call fails.
917 FailureOr<unsigned> maybeAddressSpace =
918 getTypeConverter()->getMemRefAddressSpace(type);
919 assert(succeeded(maybeAddressSpace) && "unsupported address space");
920 unsigned memSpace = *maybeAddressSpace;
921
922 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
923 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
924 auto addressOf =
925 LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());
926
927 // Get the address of the first element in the array by creating a GEP with
928 // the address of the GV as the base, and (rank + 1) number of 0 indices.
929 auto gep =
930 LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
931 SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
932
933 // We do not expect the memref obtained using `memref.get_global` to be
934 // ever deallocated. Set the allocated pointer to be known bad value to
935 // help debug if that ever happens.
936 auto intPtrType = getIntPtrType(memSpace);
937 Value deadBeefConst =
938 createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
939 auto deadBeefPtr =
940 LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);
941
942 // Both allocated and aligned pointers are same. We could potentially stash
943 // a nullptr for the allocated pointer since we do not expect any dealloc.
944 // Create the MemRef descriptor.
945 auto memRefDescriptor = this->createMemRefDescriptor(
946 loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);
947
948 // Return the final value of the descriptor.
949 rewriter.replaceOp(op, {memRefDescriptor});
950 return success();
951 }
952};
953
954// Load operation is lowered to obtaining a pointer to the indexed element
955// and loading it.
956struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
957 using Base::Base;
958
959 LogicalResult
960 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
961 ConversionPatternRewriter &rewriter) const override {
962 auto type = loadOp.getMemRefType();
963
964 // Per memref.load spec, the indices must be in-bounds:
965 // 0 <= idx < dim_size, and additionally all offsets are non-negative,
966 // hence inbounds and nuw are used when lowering to llvm.getelementptr.
967 Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
968 adaptor.getMemref(),
969 adaptor.getIndices(), kNoWrapFlags);
970 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
971 loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
972 loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal());
973 return success();
974 }
975};
976
977// Store operation is lowered to obtaining a pointer to the indexed element,
978// and storing the given value to it.
979struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
980 using Base::Base;
981
982 LogicalResult
983 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
984 ConversionPatternRewriter &rewriter) const override {
985 auto type = op.getMemRefType();
986
987 // Per memref.store spec, the indices must be in-bounds:
988 // 0 <= idx < dim_size, and additionally all offsets are non-negative,
989 // hence inbounds and nuw are used when lowering to llvm.getelementptr.
990 Value dataPtr =
991 getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
992 adaptor.getIndices(), kNoWrapFlags);
993 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
994 op.getAlignment().value_or(0),
995 false, op.getNontemporal());
996 return success();
997 }
998};
999
1000// The prefetch operation is lowered in a way similar to the load operation
1001// except that the llvm.prefetch operation is used for replacement.
1002struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
1003 using Base::Base;
1004
1005 LogicalResult
1006 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
1007 ConversionPatternRewriter &rewriter) const override {
1008 auto type = prefetchOp.getMemRefType();
1009 auto loc = prefetchOp.getLoc();
1010
1011 Value dataPtr = getStridedElementPtr(
1012 rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
1013
1014 // Replace with llvm.prefetch.
1015 IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
1016 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
1017 IntegerAttr isData =
1018 rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
1019 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
1020 localityHint, isData);
1021 return success();
1022 }
1023};
1024
1025struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
1027
1028 LogicalResult
1029 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
1030 ConversionPatternRewriter &rewriter) const override {
1031 Location loc = op.getLoc();
1032 Type operandType = op.getMemref().getType();
1033 if (isa<UnrankedMemRefType>(operandType)) {
1034 UnrankedMemRefDescriptor desc(adaptor.getMemref());
1035 rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
1036 return success();
1037 }
1038 if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
1039 Type indexType = getIndexType();
1040 rewriter.replaceOp(op,
1041 {createIndexAttrConstant(rewriter, loc, indexType,
1042 rankedMemRefType.getRank())});
1043 return success();
1044 }
1045 return failure();
1046 }
1047};
1048
1049struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
1051
1052 LogicalResult
1053 matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
1054 ConversionPatternRewriter &rewriter) const override {
1055 Type srcType = memRefCastOp.getOperand().getType();
1056 Type dstType = memRefCastOp.getType();
1057
1058 // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
1059 // used for type erasure. For now they must preserve underlying element type
1060 // and require source and result type to have the same rank. Therefore,
1061 // perform a sanity check that the underlying structs are the same. Once op
1062 // semantics are relaxed we can revisit.
1063 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
1064 if (typeConverter->convertType(srcType) !=
1065 typeConverter->convertType(dstType))
1066 return failure();
1067
1068 // Unranked to unranked cast is disallowed
1069 if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
1070 return failure();
1071
1072 auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
1073 auto loc = memRefCastOp.getLoc();
1074
1075 // For ranked/ranked case, just keep the original descriptor.
1076 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
1077 rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
1078 return success();
1079 }
1080
1081 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
1082 // Casting ranked to unranked memref type
1083 // Set the rank in the destination from the memref type
1084 // Allocate space on the stack and copy the src memref descriptor
1085 // Set the ptr in the destination to the stack space
1086 auto srcMemRefType = cast<MemRefType>(srcType);
1087 int64_t rank = srcMemRefType.getRank();
1088 // ptr = AllocaOp sizeof(MemRefDescriptor)
1089 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
1090 loc, adaptor.getSource(), rewriter);
1091
1092 // rank = ConstantOp srcRank
1093 auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1094 rewriter.getIndexAttr(rank));
1095 // poison = PoisonOp
1096 UnrankedMemRefDescriptor memRefDesc =
1097 UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
1098 // d1 = InsertValueOp poison, rank, 0
1099 memRefDesc.setRank(rewriter, loc, rankVal);
1100 // d2 = InsertValueOp d1, ptr, 1
1101 memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
1102 rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
1103
1104 } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
1105 // Casting from unranked type to ranked.
1106 // The operation is assumed to be doing a correct cast. If the destination
1107 // type mismatches the unranked the type, it is undefined behavior.
1108 UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
1109 // ptr = ExtractValueOp src, 1
1110 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
1111
1112 // struct = LoadOp ptr
1113 auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr);
1114 rewriter.replaceOp(memRefCastOp, loadOp.getResult());
1115 } else {
1116 llvm_unreachable("Unsupported unranked memref to unranked memref cast");
1117 }
1118
1119 return success();
1120 }
1121};
1122
1123/// Pattern to lower a `memref.copy` to llvm.
1124///
1125/// For memrefs with identity layouts, the copy is lowered to the llvm
1126/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
1127/// to the generic `MemrefCopyFn`.
1128class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
1129 SymbolTableCollection *symbolTables = nullptr;
1130
1131public:
1132 explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter,
1133 SymbolTableCollection *symbolTables = nullptr,
1134 PatternBenefit benefit = 1)
1135 : ConvertOpToLLVMPattern<memref::CopyOp>(typeConverter, benefit),
1136 symbolTables(symbolTables) {}
1137
1138 LogicalResult
1139 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
1140 ConversionPatternRewriter &rewriter) const {
1141 auto loc = op.getLoc();
1142 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
1143
1144 MemRefDescriptor srcDesc(adaptor.getSource());
1145
1146 // Compute number of elements.
1147 Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1148 rewriter.getIndexAttr(1));
1149 for (int pos = 0; pos < srcType.getRank(); ++pos) {
1150 auto size = srcDesc.size(rewriter, loc, pos);
1151 numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
1152 }
1153
1154 // Get element size.
1155 auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
1156 // Compute total.
1157 Value totalSize =
1158 LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);
1159
1160 Type elementType = typeConverter->convertType(srcType.getElementType());
1161
1162 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
1163 Value srcOffset = srcDesc.offset(rewriter, loc);
1164 Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(),
1165 elementType, srcBasePtr, srcOffset);
1166 MemRefDescriptor targetDesc(adaptor.getTarget());
1167 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
1168 Value targetOffset = targetDesc.offset(rewriter, loc);
1169 Value targetPtr =
1170 LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType,
1171 targetBasePtr, targetOffset);
1172 LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
1173 /*isVolatile=*/false);
1174 rewriter.eraseOp(op);
1175
1176 return success();
1177 }
1178
1179 LogicalResult
1180 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
1181 ConversionPatternRewriter &rewriter) const {
1182 auto loc = op.getLoc();
1183 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1184 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1185
1186 // First make sure we have an unranked memref descriptor representation.
1187 auto makeUnranked = [&, this](Value ranked, MemRefType type) {
1188 auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1189 type.getRank());
1190 auto *typeConverter = getTypeConverter();
1191 auto ptr =
1192 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
1193
1194 auto unrankedType =
1195 UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
1197 rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
1198 };
1199
1200 // Save stack position before promoting descriptors
1201 auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
1202
1203 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
1204 Value unrankedSource =
1205 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
1206 : adaptor.getSource();
1207 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
1208 Value unrankedTarget =
1209 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
1210 : adaptor.getTarget();
1211
1212 // Now promote the unranked descriptors to the stack.
1213 auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1214 rewriter.getIndexAttr(1));
1215 auto promote = [&](Value desc) {
1216 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1217 auto allocated =
1218 LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
1219 LLVM::StoreOp::create(rewriter, loc, desc, allocated);
1220 return allocated;
1221 };
1222
1223 auto sourcePtr = promote(unrankedSource);
1224 auto targetPtr = promote(unrankedTarget);
1225
1226 // Derive size from llvm.getelementptr which will account for any
1227 // potential alignment
1228 auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
1230 rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
1231 sourcePtr.getType(), symbolTables);
1232 if (failed(copyFn))
1233 return failure();
1234 LLVM::CallOp::create(rewriter, loc, copyFn.value(),
1235 ValueRange{elemSize, sourcePtr, targetPtr});
1236
1237 // Restore stack used for descriptors
1238 LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
1239
1240 rewriter.eraseOp(op);
1241
1242 return success();
1243 }
1244
1245 LogicalResult
1246 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
1247 ConversionPatternRewriter &rewriter) const override {
1248 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1249 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1250
1251 auto isContiguousMemrefType = [&](BaseMemRefType type) {
1252 auto memrefType = dyn_cast<mlir::MemRefType>(type);
1253 // We can use memcpy for memrefs if they have an identity layout or are
1254 // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
1255 // special case handled by memrefCopy.
1256 return memrefType &&
1257 (memrefType.getLayout().isIdentity() ||
1258 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
1260 };
1261
1262 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
1263 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
1264
1265 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
1266 }
1267};
1268
1269struct MemorySpaceCastOpLowering
1270 : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
1272 memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;
1273
1274 LogicalResult
1275 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
1276 ConversionPatternRewriter &rewriter) const override {
1277 Location loc = op.getLoc();
1278
1279 Type resultType = op.getDest().getType();
1280 if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
1281 auto convertedType =
1282 typeConverter->convertType<LLVM::LLVMStructType>(resultTypeR);
1283 if (!convertedType)
1284 return rewriter.notifyMatchFailure(op, "memref type conversion failed");
1285 Type newPtrType = convertedType.getBody()[0];
1286
1287 SmallVector<Value> descVals;
1288 MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
1289 descVals);
1290 descVals[0] =
1291 LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
1292 descVals[1] =
1293 LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
1294 Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
1295 resultTypeR, descVals);
1296 rewriter.replaceOp(op, result);
1297 return success();
1298 }
1299 if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
1300 // Since the type converter won't be doing this for us, get the address
1301 // space.
1302 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
1303 FailureOr<unsigned> maybeSourceAddrSpace =
1304 getTypeConverter()->getMemRefAddressSpace(sourceType);
1305 if (failed(maybeSourceAddrSpace))
1306 return rewriter.notifyMatchFailure(loc,
1307 "non-integer source address space");
1308 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
1309 FailureOr<unsigned> maybeResultAddrSpace =
1310 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
1311 if (failed(maybeResultAddrSpace))
1312 return rewriter.notifyMatchFailure(loc,
1313 "non-integer result address space");
1314 unsigned resultAddrSpace = *maybeResultAddrSpace;
1315
1316 UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
1317 Value rank = sourceDesc.rank(rewriter, loc);
1318 Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
1319
1320 // Create and allocate storage for new memref descriptor.
1322 rewriter, loc, typeConverter->convertType(resultTypeU));
1323 result.setRank(rewriter, loc, rank);
1324 Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
1325 rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
1326 Value resultUnderlyingDesc =
1327 LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
1328 rewriter.getI8Type(), resultUnderlyingSize);
1329 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
1330
1331 // Copy pointers, performing address space casts.
1332 auto sourceElemPtrType =
1333 LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
1334 auto resultElemPtrType =
1335 LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);
1336
1337 Value allocatedPtr = sourceDesc.allocatedPtr(
1338 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
1339 Value alignedPtr =
1340 sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
1341 sourceUnderlyingDesc, sourceElemPtrType);
1342 allocatedPtr = LLVM::AddrSpaceCastOp::create(
1343 rewriter, loc, resultElemPtrType, allocatedPtr);
1344 alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
1345 resultElemPtrType, alignedPtr);
1346
1347 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
1348 resultElemPtrType, allocatedPtr);
1349 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
1350 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
1351
1352 // Copy all the index-valued operands.
1353 Value sourceIndexVals =
1354 sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1355 sourceUnderlyingDesc, sourceElemPtrType);
1356 Value resultIndexVals =
1357 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1358 resultUnderlyingDesc, resultElemPtrType);
1359
1360 int64_t bytesToSkip =
1361 2 * llvm::divideCeil(
1362 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
1363 Value bytesToSkipConst = LLVM::ConstantOp::create(
1364 rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
1365 Value copySize =
1366 LLVM::SubOp::create(rewriter, loc, getIndexType(),
1367 resultUnderlyingSize, bytesToSkipConst);
1368 LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
1369 copySize, /*isVolatile=*/false);
1370
1371 rewriter.replaceOp(op, ValueRange{result});
1372 return success();
1373 }
1374 return rewriter.notifyMatchFailure(loc, "unexpected memref type");
1375 }
1376};
1377
1378/// Extracts allocated, aligned pointers and offset from a ranked or unranked
1379/// memref type. In unranked case, the fields are extracted from the underlying
1380/// ranked descriptor.
1381static void extractPointersAndOffset(Location loc,
1382 ConversionPatternRewriter &rewriter,
1383 const LLVMTypeConverter &typeConverter,
1384 Value originalOperand,
1385 Value convertedOperand,
1386 Value *allocatedPtr, Value *alignedPtr,
1387 Value *offset = nullptr) {
1388 Type operandType = originalOperand.getType();
1389 if (isa<MemRefType>(operandType)) {
1390 MemRefDescriptor desc(convertedOperand);
1391 *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1392 *alignedPtr = desc.alignedPtr(rewriter, loc);
1393 if (offset != nullptr)
1394 *offset = desc.offset(rewriter, loc);
1395 return;
1396 }
1397
1398 // These will all cause assert()s on unconvertible types.
1399 unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
1400 cast<UnrankedMemRefType>(operandType));
1401 auto elementPtrType =
1402 LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);
1403
1404 // Extract pointer to the underlying ranked memref descriptor and cast it to
1405 // ElemType**.
1406 UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1407 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1408
1410 rewriter, loc, underlyingDescPtr, elementPtrType);
1412 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1413 if (offset != nullptr) {
1415 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1416 }
1417}
1418
1419struct MemRefReinterpretCastOpLowering
1420 : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1422 memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1423
1424 LogicalResult
1425 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1426 ConversionPatternRewriter &rewriter) const override {
1427 Type srcType = castOp.getSource().getType();
1428
1429 Value descriptor;
1430 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1431 adaptor, &descriptor)))
1432 return failure();
1433 rewriter.replaceOp(castOp, {descriptor});
1434 return success();
1435 }
1436
1437private:
1438 LogicalResult convertSourceMemRefToDescriptor(
1439 ConversionPatternRewriter &rewriter, Type srcType,
1440 memref::ReinterpretCastOp castOp,
1441 memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1442 MemRefType targetMemRefType =
1443 cast<MemRefType>(castOp.getResult().getType());
1444 auto llvmTargetDescriptorTy =
1445 typeConverter->convertType<LLVM::LLVMStructType>(targetMemRefType);
1446 if (!llvmTargetDescriptorTy)
1447 return failure();
1448
1449 // Create descriptor.
1450 Location loc = castOp.getLoc();
1451 auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1452
1453 // Set allocated and aligned pointers.
1454 Value allocatedPtr, alignedPtr;
1455 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1456 castOp.getSource(), adaptor.getSource(),
1457 &allocatedPtr, &alignedPtr);
1458 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1459 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1460
1461 // Set offset.
1462 if (castOp.isDynamicOffset(0))
1463 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1464 else
1465 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1466
1467 // Set sizes and strides.
1468 unsigned dynSizeId = 0;
1469 unsigned dynStrideId = 0;
1470 for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1471 if (castOp.isDynamicSize(i))
1472 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1473 else
1474 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1475
1476 if (castOp.isDynamicStride(i))
1477 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1478 else
1479 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1480 }
1481 *descriptor = desc;
1482 return success();
1483 }
1484};
1485
1486struct MemRefReshapeOpLowering
1487 : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1488 using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1489
1490 LogicalResult
1491 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1492 ConversionPatternRewriter &rewriter) const override {
1493 Type srcType = reshapeOp.getSource().getType();
1494
1495 Value descriptor;
1496 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1497 adaptor, &descriptor)))
1498 return failure();
1499 rewriter.replaceOp(reshapeOp, {descriptor});
1500 return success();
1501 }
1502
1503private:
1504 LogicalResult
1505 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1506 Type srcType, memref::ReshapeOp reshapeOp,
1507 memref::ReshapeOp::Adaptor adaptor,
1508 Value *descriptor) const {
1509 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1510 if (shapeMemRefType.hasStaticShape()) {
1511 MemRefType targetMemRefType =
1512 cast<MemRefType>(reshapeOp.getResult().getType());
1513 auto llvmTargetDescriptorTy =
1514 typeConverter->convertType<LLVM::LLVMStructType>(targetMemRefType);
1515 if (!llvmTargetDescriptorTy)
1516 return failure();
1517
1518 // Create descriptor.
1519 Location loc = reshapeOp.getLoc();
1520 auto desc =
1521 MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1522
1523 // Set allocated and aligned pointers.
1524 Value allocatedPtr, alignedPtr;
1525 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1526 reshapeOp.getSource(), adaptor.getSource(),
1527 &allocatedPtr, &alignedPtr);
1528 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1529 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1530
1531 // Extract the offset and strides from the type.
1532 int64_t offset;
1533 SmallVector<int64_t> strides;
1534 if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1535 return rewriter.notifyMatchFailure(
1536 reshapeOp, "failed to get stride and offset exprs");
1537
1538 if (!isStaticStrideOrOffset(offset))
1539 return rewriter.notifyMatchFailure(reshapeOp,
1540 "dynamic offset is unsupported");
1541
1542 desc.setConstantOffset(rewriter, loc, offset);
1543
1544 assert(targetMemRefType.getLayout().isIdentity() &&
1545 "Identity layout map is a precondition of a valid reshape op");
1546
1547 Type indexType = getIndexType();
1548 Value stride = nullptr;
1549 int64_t targetRank = targetMemRefType.getRank();
1550 for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1551 if (ShapedType::isStatic(strides[i])) {
1552 // If the stride for this dimension is dynamic, then use the product
1553 // of the sizes of the inner dimensions.
1554 stride =
1555 createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
1556 } else if (!stride) {
1557 // `stride` is null only in the first iteration of the loop. However,
1558 // since the target memref has an identity layout, we can safely set
1559 // the innermost stride to 1.
1560 stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
1561 }
1562
1563 Value dimSize;
1564 // If the size of this dimension is dynamic, then load it at runtime
1565 // from the shape operand.
1566 if (!targetMemRefType.isDynamicDim(i)) {
1567 dimSize = createIndexAttrConstant(rewriter, loc, indexType,
1568 targetMemRefType.getDimSize(i));
1569 } else {
1570 Value shapeOp = reshapeOp.getShape();
1571 Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
1572 dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index);
1573 Type indexType = getIndexType();
1574 if (dimSize.getType() != indexType)
1575 dimSize = typeConverter->materializeTargetConversion(
1576 rewriter, loc, indexType, dimSize);
1577 assert(dimSize && "Invalid memref element type");
1578 }
1579
1580 desc.setSize(rewriter, loc, i, dimSize);
1581 desc.setStride(rewriter, loc, i, stride);
1582
1583 // Prepare the stride value for the next dimension.
1584 stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
1585 }
1586
1587 *descriptor = desc;
1588 return success();
1589 }
1590
1591 // The shape is a rank-1 tensor with unknown length.
1592 Location loc = reshapeOp.getLoc();
1593 MemRefDescriptor shapeDesc(adaptor.getShape());
1594 Value resultRank = shapeDesc.size(rewriter, loc, 0);
1595
1596 // Extract address space and element type.
1597 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1598 unsigned addressSpace =
1599 *getTypeConverter()->getMemRefAddressSpace(targetType);
1600
1601 // Create the unranked memref descriptor that holds the ranked one. The
1602 // inner descriptor is allocated on stack.
1603 auto targetDesc = UnrankedMemRefDescriptor::poison(
1604 rewriter, loc, typeConverter->convertType(targetType));
1605 targetDesc.setRank(rewriter, loc, resultRank);
1607 rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
1608 Value underlyingDescPtr = LLVM::AllocaOp::create(
1609 rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
1610 allocationSize);
1611 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1612
1613 // Extract pointers and offset from the source memref.
1614 Value allocatedPtr, alignedPtr, offset;
1615 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1616 reshapeOp.getSource(), adaptor.getSource(),
1617 &allocatedPtr, &alignedPtr, &offset);
1618
1619 // Set pointers and offset.
1620 auto elementPtrType =
1621 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1622
1623 UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1624 elementPtrType, allocatedPtr);
1625 UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1626 underlyingDescPtr, elementPtrType,
1627 alignedPtr);
1628 UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1629 underlyingDescPtr, elementPtrType,
1630 offset);
1631
1632 // Use the offset pointer as base for further addressing. Copy over the new
1633 // shape and compute strides. For this, we create a loop from rank-1 to 0.
1635 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1637 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1638 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1639 Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
1640 Value resultRankMinusOne =
1641 LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);
1642
1643 Block *initBlock = rewriter.getInsertionBlock();
1644 Type indexType = getTypeConverter()->getIndexType();
1645 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1646
1647 Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1648 {indexType, indexType}, {loc, loc});
1649
1650 // Move the remaining initBlock ops to condBlock.
1651 Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1652 rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1653
1654 rewriter.setInsertionPointToEnd(initBlock);
1655 LLVM::BrOp::create(rewriter, loc,
1656 ValueRange({resultRankMinusOne, oneIndex}), condBlock);
1657 rewriter.setInsertionPointToStart(condBlock);
1658 Value indexArg = condBlock->getArgument(0);
1659 Value strideArg = condBlock->getArgument(1);
1660
1661 Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
1662 Value pred = LLVM::ICmpOp::create(
1663 rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
1664 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1665
1666 Block *bodyBlock =
1667 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1668 rewriter.setInsertionPointToStart(bodyBlock);
1669
1670 // Copy size from shape to descriptor.
1671 auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1672 Value sizeLoadGep = LLVM::GEPOp::create(
1673 rewriter, loc, llvmIndexPtrType,
1674 typeConverter->convertType(shapeMemRefType.getElementType()),
1675 shapeOperandPtr, indexArg);
1676 Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
1677 UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1678 targetSizesBase, indexArg, size);
1679
1680 // Write stride value and compute next one.
1681 UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1682 targetStridesBase, indexArg, strideArg);
1683 Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);
1684
1685 // Decrement loop counter and branch back.
1686 Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
1687 LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}),
1688 condBlock);
1689
1690 Block *remainder =
1691 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1692
1693 // Hook up the cond exit to the remainder.
1694 rewriter.setInsertionPointToEnd(condBlock);
1695 LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(),
1696 remainder, ValueRange());
1697
1698 // Reset position to beginning of new remainder block.
1699 rewriter.setInsertionPointToStart(remainder);
1700
1701 *descriptor = targetDesc;
1702 return success();
1703 }
1704};
1705
/// ReassociatingReshapeOp (expand_shape / collapse_shape) must be expanded
/// before we reach this stage; this pattern only reports that expectation.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  // NOTE(review): an inherited-constructor using-declaration appears to be
  // missing here (this pattern is constructed from a type converter like its
  // siblings) — verify against upstream.
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Always fail to match: earlier expansion passes are responsible for
    // lowering reassociating reshapes.
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};
1723
1724/// Subviews must be expanded before we reach this stage.
1725/// Report that information.
1726struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1727 using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1728
1729 LogicalResult
1730 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1731 ConversionPatternRewriter &rewriter) const override {
1732 return rewriter.notifyMatchFailure(
1733 subViewOp, "subview operations should have been expanded beforehand");
1734 }
1735};
1736
1737/// Conversion pattern that transforms a transpose op into:
1738/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
1739/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
1740/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
1741/// and stride. Size and stride are permutations of the original values.
1742/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
1743/// The transpose op is replaced by the alloca'ed pointer.
1744class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1745public:
1746 using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1747
1748 LogicalResult
1749 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1750 ConversionPatternRewriter &rewriter) const override {
1751 auto loc = transposeOp.getLoc();
1752 MemRefDescriptor viewMemRef(adaptor.getIn());
1753
1754 // No permutation, early exit.
1755 if (transposeOp.getPermutation().isIdentity())
1756 return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1757
1758 auto targetMemRef = MemRefDescriptor::poison(
1759 rewriter, loc,
1760 typeConverter->convertType(transposeOp.getIn().getType()));
1761
1762 // Copy the base and aligned pointers from the old descriptor to the new
1763 // one.
1764 targetMemRef.setAllocatedPtr(rewriter, loc,
1765 viewMemRef.allocatedPtr(rewriter, loc));
1766 targetMemRef.setAlignedPtr(rewriter, loc,
1767 viewMemRef.alignedPtr(rewriter, loc));
1768
1769 // Copy the offset pointer from the old descriptor to the new one.
1770 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1771
1772 // Iterate over the dimensions and apply size/stride permutation:
1773 // When enumerating the results of the permutation map, the enumeration
1774 // index is the index into the target dimensions and the DimExpr points to
1775 // the dimension of the source memref.
1776 for (const auto &en :
1777 llvm::enumerate(transposeOp.getPermutation().getResults())) {
1778 int targetPos = en.index();
1779 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1780 targetMemRef.setSize(rewriter, loc, targetPos,
1781 viewMemRef.size(rewriter, loc, sourcePos));
1782 targetMemRef.setStride(rewriter, loc, targetPos,
1783 viewMemRef.stride(rewriter, loc, sourcePos));
1784 }
1785
1786 rewriter.replaceOp(transposeOp, {targetMemRef});
1787 return success();
1788 }
1789};
1790
/// Conversion pattern that transforms a memref.view op into:
/// 1. A poison MemRef descriptor value of the converted result type.
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
/// and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  // NOTE(review): an inherited-constructor using-declaration appears to be
  // missing here (this pattern is constructed from a type converter like its
  // siblings) — verify against upstream.

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic size.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
                Type indexType) const {
    assert(idx < shape.size());
    if (ShapedType::isStatic(shape[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
    // Count the number of dynamic dims in range [0, idx): `dynamicSizes`
    // only holds values for the dynamic dimensions, in order.
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`. The caller should keep a running stride and update it with the
  // result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx, Type indexType) const {
    assert(idx < strides.size());
    if (ShapedType::isStatic(strides[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
    // Dynamic stride: accumulate runningStride * nextSize; the first dynamic
    // stride encountered is nextSize itself.
    if (nextSize)
      return runningStride
                 ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexAttrConstant(rewriter, loc, indexType, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    // The result element and descriptor types must convert to
    // LLVM-compatible types, otherwise this pattern cannot apply.
    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    // NOTE(review): the declaration of `strides` (presumably a
    // SmallVector<int64_t>) appears to have been lost just above this line —
    // verify against upstream.
    auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
    targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);

    // Field 2: Copy the actual aligned pointer to payload, advanced by the
    // byte-shift operand of the view op.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = LLVM::GEPOp::create(
        rewriter, loc, alignedPtr.getType(),
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());

    targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);

    Type indexType = getIndexType();
    // Field 3: The offset in the resulting type must be 0. This is
    // because of the type change: an offset on srcType* may not be
    // expressible as an offset on dstType*.
    targetMemRef.setOffset(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, indexType, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides, innermost dimension first so
    // the running stride can be accumulated from the sizes.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i, indexType);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride =
          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};
1907
1908//===----------------------------------------------------------------------===//
1909// AtomicRMWOpLowering
1910//===----------------------------------------------------------------------===//
1911
1912/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1913/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1914static std::optional<LLVM::AtomicBinOp>
1915matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1916 switch (atomicOp.getKind()) {
1917 case arith::AtomicRMWKind::addf:
1918 return LLVM::AtomicBinOp::fadd;
1919 case arith::AtomicRMWKind::addi:
1920 return LLVM::AtomicBinOp::add;
1921 case arith::AtomicRMWKind::assign:
1922 return LLVM::AtomicBinOp::xchg;
1923 case arith::AtomicRMWKind::maximumf:
1924 // TODO: remove this by end of 2025.
1925 LDBG() << "the lowering of memref.atomicrmw maximumf changed "
1926 "from fmax to fmaximum, expect more NaNs";
1927 return LLVM::AtomicBinOp::fmaximum;
1928 case arith::AtomicRMWKind::maxnumf:
1929 return LLVM::AtomicBinOp::fmax;
1930 case arith::AtomicRMWKind::maxs:
1931 return LLVM::AtomicBinOp::max;
1932 case arith::AtomicRMWKind::maxu:
1933 return LLVM::AtomicBinOp::umax;
1934 case arith::AtomicRMWKind::minimumf:
1935 // TODO: remove this by end of 2025.
1936 LDBG() << "the lowering of memref.atomicrmw minimum changed "
1937 "from fmin to fminimum, expect more NaNs";
1938 return LLVM::AtomicBinOp::fminimum;
1939 case arith::AtomicRMWKind::minnumf:
1940 return LLVM::AtomicBinOp::fmin;
1941 case arith::AtomicRMWKind::mins:
1942 return LLVM::AtomicBinOp::min;
1943 case arith::AtomicRMWKind::minu:
1944 return LLVM::AtomicBinOp::umin;
1945 case arith::AtomicRMWKind::ori:
1946 return LLVM::AtomicBinOp::_or;
1947 case arith::AtomicRMWKind::xori:
1948 return LLVM::AtomicBinOp::_xor;
1949 case arith::AtomicRMWKind::andi:
1950 return LLVM::AtomicBinOp::_and;
1951 default:
1952 return std::nullopt;
1953 }
1954 llvm_unreachable("Invalid AtomicRMWKind");
1955}
1956
1957struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1958 using Base::Base;
1959
1960 LogicalResult
1961 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1962 ConversionPatternRewriter &rewriter) const override {
1963 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1964 if (!maybeKind)
1965 return failure();
1966 auto memRefType = atomicOp.getMemRefType();
1967 SmallVector<int64_t> strides;
1968 int64_t offset;
1969 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1970 return failure();
1971 auto dataPtr =
1972 getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
1973 adaptor.getMemref(), adaptor.getIndices());
1974 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1975 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1976 LLVM::AtomicOrdering::acq_rel);
1977 return success();
1978 }
1979};
1980
/// Lower memref.extract_aligned_pointer_as_index: read the aligned pointer
/// out of the (un)ranked descriptor and convert it to the index type.
class ConvertExtractAlignedPointerAsIndex
    : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  // NOTE(review): the first half of this inherited-constructor declaration
  // ("using ConvertOpToLLVMPattern<") looks truncated — verify upstream.
      memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    BaseMemRefType sourceTy = extractOp.getSource().getType();

    Value alignedPtr;
    if (sourceTy.hasRank()) {
      // Ranked: the aligned pointer is a direct field of the descriptor.
      MemRefDescriptor desc(adaptor.getSource());
      alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
    } else {
      // Unranked: first chase the pointer to the underlying ranked
      // descriptor, then read the aligned pointer out of it.
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), sourceTy.getMemorySpaceAsInt());

      UnrankedMemRefDescriptor desc(adaptor.getSource());
      Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());

      // NOTE(review): the assignment head (presumably
      // "alignedPtr = UnrankedMemRefDescriptor::alignedPtr(") appears to be
      // missing before the next line — verify against upstream.
          rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
          elementPtrTy);
    }

    // Replace the op with a pointer-to-index conversion of the aligned ptr.
    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
        extractOp, getTypeConverter()->getIndexType(), alignedPtr);
    return success();
  }
};
2015
/// Materialize the MemRef descriptor represented by the results of
/// ExtractStridedMetadataOp: base buffer, offset, then the per-dimension
/// sizes and strides, all read out of the source descriptor.
class ExtractStridedMetadataOpLowering
    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
public:
  // NOTE(review): the first half of this inherited-constructor declaration
  // ("using ConvertOpToLLVMPattern<") looks truncated — verify upstream.
      memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Bail out unless the source was already converted to an LLVM-compatible
    // descriptor value.
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();

    auto sourceMemRefType = cast<MemRefType>(source.getType());
    int64_t rank = sourceMemRefType.getRank();
    // Results: base buffer, offset, `rank` sizes, `rank` strides.
    SmallVector<Value> results;
    results.reserve(2 + rank * 2);

    // Base buffer.
    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
    // NOTE(review): the declaration head (presumably "MemRefDescriptor
    // dstMemRef = MemRefDescriptor::fromStaticShape(") appears to be missing
    // before the next line — verify against upstream.
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
        baseBuffer, alignedBuffer);
    results.push_back((Value)dstMemRef);

    // Offset.
    results.push_back(sourceMemRef.offset(rewriter, loc));

    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.size(rewriter, loc, i));
    // Strides.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.stride(rewriter, loc, i));

    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};
2065
2066} // namespace
2067
    // NOTE(review): the function header line (presumably "void
    // mlir::populateFinalizeMemRefToLLVMConversionPatterns(") appears to be
    // missing immediately above this parameter list — verify upstream.
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    SymbolTableCollection *symbolTables) {
  // Register one lowering pattern per memref op handled by this pass.
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AssumeAlignmentOpLowering,
      AtomicRMWOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      DimOpLowering,
      DistinctObjectsOpLowering,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      MemorySpaceCastOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      // NOTE(review): a pattern entry (presumably "TransposeOpLowering,")
      // appears to have been dropped here — verify against upstream.
      ViewOpLowering>(converter);
  // clang-format on
  // These patterns additionally take a symbol-table collection so they can
  // look up or create helper symbols.
  patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
                                                             symbolTables);
  // Select the memref.alloc lowering (aligned_alloc vs. malloc) from the
  // conversion options.
  auto allocLowering = converter.getOptions().allocLowering;
  // NOTE(review): the condition line (presumably "if (allocLowering ==
  // LowerToLLVMOptions::AllocLowering::AlignedAlloc)") appears to be missing
  // here — verify against upstream.
  patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
                                                          symbolTables);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
}
2106
2107namespace {
/// Pass that finalizes the conversion of memref dialect ops to the LLVM
/// dialect, configuring the type converter from the module's data layout.
struct FinalizeMemRefToLLVMConversionPass
    // NOTE(review): the base-class specifier line (presumably ": public
    // impl::FinalizeMemRefToLLVMConversionPassBase<") appears to be missing
    // here — verify against upstream.
      FinalizeMemRefToLLVMConversionPass> {
  using FinalizeMemRefToLLVMConversionPassBase::
      FinalizeMemRefToLLVMConversionPassBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    // Derive lowering options from the data layout at or above this op.
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    // NOTE(review): the declaration head (presumably "LowerToLLVMOptions
    // options(&getContext(),") appears to be missing before the next line —
    // verify against upstream.
        dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
    // NOTE(review): the assigned value (an AlignedAlloc vs. Malloc selection
    // driven by a pass option) appears to be missing here — verify against
    // upstream.

    options.useGenericFunctions = useGenericFunctions;

    // Honor an explicit index-bitwidth override from the pass options.
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    SymbolTableCollection symbolTables;
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns,
                                                   &symbolTables);
    // NOTE(review): the conversion-target declaration (presumably
    // "LLVMConversionTarget target(getContext());") appears to be missing
    // here — verify against upstream.
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
2140
2141/// Implement the interface to convert MemRef to LLVM.
2142struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
2143 MemRefToLLVMDialectInterface(Dialect *dialect)
2144 : ConvertToLLVMPatternInterface(dialect) {}
2145
2146 void loadDependentDialects(MLIRContext *context) const final {
2147 context->loadDialect<LLVM::LLVMDialect>();
2148 }
2149
2150 /// Hook for derived dialect interface to provide conversion patterns
2151 /// and mark dialect legal for the conversion target.
2152 void populateConvertToLLVMConversionPatterns(
2153 ConversionTarget &target, LLVMTypeConverter &typeConverter,
2154 RewritePatternSet &patterns) const final {
2155 populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
2156 }
2157};
2158
2159} // namespace
2160
  // NOTE(review): the enclosing function signature (presumably "void
  // mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {")
  // appears to be missing immediately above — verify against upstream.
  // Attach the ConvertToLLVM interface lazily, when the memref dialect is
  // loaded into a context that uses this registry.
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    dialect->addInterfaces<MemRefToLLVMDialectInterface>();
  });
}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags
static llvm::Value * getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLowerin...
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:58
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
const LowerToLLVMOptions & getOptions() const
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const DataLayoutAnalysis * getDataLayoutAnalysis() const
Returns the data layout analysis to query during conversion.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Options to control the LLVM lowering.
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
Value allocatedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the allocated pointer from the descriptor.
Value offset(OpBuilder &builder, Location loc)
Builds IR extracting the offset from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:274
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents a collection of SymbolTables.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
static void setOffset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value offset)
Builds IR inserting the offset into the descriptor.
static Value allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
TODO: The following accessors don't take alignment rules between elements of the descriptor struct in...
static Value computeSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, UnrankedMemRefDescriptor desc, unsigned addressSpace)
Builds and returns IR computing the size in bytes (suitable for opaque allocation).
void setRank(OpBuilder &builder, Location loc, Value value)
Builds IR setting the rank in the descriptor.
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static void setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr)
Builds IR inserting the allocated pointer into the descriptor.
static void setSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value index, Value size)
Builds IR inserting the size[index] into the descriptor.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
static void setAlignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr)
Builds IR inserting the aligned pointer into the descriptor.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
static Value offsetBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR for getting the pointer to the offset's location.
static Value offset(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the offset from the descriptor.
static Value strideBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank)
Builds IR extracting the pointer to the first element of the stride array.
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value)
Builds IR setting ranked memref descriptor ptr.
static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride)
Builds IR inserting the stride[index] into the descriptor.
static Value sizeBasePtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the pointer to the first element of the size array.
static Value alignedPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType)
Builds IR extracting the aligned pointer from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
Include the generated interface declarations.
void registerConvertMemRefToLLVMInterface(DialectRegistry &registry)
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the data layout.
void populateFinalizeMemRefToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect a set of patterns to convert memory-related operations from the MemRef dialect to the LLVM dialect.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.