//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"

#include <optional>

#define DEBUG_TYPE "memref-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
    LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;

namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return ShapedType::isStatic(strideOrOffset);
}

static FailureOr<LLVM::LLVMFuncOp>
getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
          Operation *module, SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(b, module, symbolTables);

  return LLVM::lookupOrCreateFreeFn(b, module, symbolTables);
}

static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                     Operation *module, Type indexType,
                     SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType,
                                              symbolTables);

  return LLVM::lookupOrCreateMallocFn(b, module, indexType, symbolTables);
}

static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                  Operation *module, Type indexType,
                  SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType,
                                                     symbolTables);

  return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType, symbolTables);
}

/// Computes the aligned value for 'input' as follows:
///   bumped = input + alignment - 1
///   aligned = bumped - bumped % alignment
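///
/// For example, with input = 13 and alignment = 8: bumped = 20, and
/// aligned = 20 - (20 % 8) = 16, the smallest multiple of 8 that is >= 13.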
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
                           Value input, Value alignment) {
  Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(),
                                       rewriter.getIndexAttr(1));
  Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
  Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
  Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
  return LLVM::SubOp::create(rewriter, loc, bumped, mod);
}

/// Computes the byte size for the MemRef element type.
static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
                                        MemRefType memRefType, Operation *op,
                                        const DataLayout *defaultLayout) {
  const DataLayout *layout = defaultLayout;
  if (const DataLayoutAnalysis *analysis =
          typeConverter->getDataLayoutAnalysis()) {
    layout = &analysis->getAbove(op);
  }
  Type elementType = memRefType.getElementType();
  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
    return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
    return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
                                                          *layout);
  return layout->getTypeSize(elementType);
}

static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                 Location loc, Value allocatedPtr,
                                 MemRefType memRefType, Type elementPtrType,
                                 const LLVMTypeConverter &typeConverter) {
  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
  FailureOr<unsigned> maybeMemrefAddrSpace =
      typeConverter.getMemRefAddressSpace(memRefType);
  assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
    allocatedPtr = LLVM::AddrSpaceCastOp::create(
        rewriter, loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
        allocatedPtr);
  return allocatedPtr;
}

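/// Lowers `memref.alloc` to a call to `malloc` (or the generic allocation
/// function when requested) followed by construction of the memref
/// descriptor. As a rough sketch (types and attributes abridged),
/// `%m = memref.alloc() : memref<4x8xf32>` becomes:
///   %sz  = llvm.mlir.constant(128 : index) : i64   // 4 * 8 * sizeof(f32)
///   %ptr = llvm.call @malloc(%sz) : (i64) -> !llvm.ptr
///   // ... then %ptr and the constant sizes {4, 8} / strides {8, 1} are
///   // inserted into the descriptor struct.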
class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit AllocOpLowering(const LLVMTypeConverter &typeConverter,
                           SymbolTableCollection *symbolTables = nullptr,
                           PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
        getNotalignedAllocFn(rewriter, getTypeConverter(),
                             op->getParentWithTrait<OpTrait::SymbolTable>(),
                             getIndexType(), symbolTables);
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    Value alignment = getAlignment(rewriter, loc, op);
    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer.
    Type elementPtrType = this->getElementPtrType(memRefType);
    assert(elementPtrType && "could not compute element ptr type");
    auto results =
        LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);

    Value allocatedPtr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());
    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
    }

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// Computes the alignment for the given memory allocation op.
  template <typename OpType>
  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
                     OpType op) const {
    MemRefType memRefType = op.getType();
    Value alignment;
    if (auto alignmentAttr = op.getAlignment()) {
      Type indexType = getIndexType();
      alignment =
          createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }
    return alignment;
  }
};

class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter,
                                  SymbolTableCollection *symbolTables = nullptr,
                                  PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert alloc function into module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
        getAlignedAllocFn(rewriter, getTypeConverter(),
                          op->getParentWithTrait<OpTrait::SymbolTable>(),
                          getIndexType(), symbolTables);
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);

    Value allocAlignment =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);

    // Function aligned_alloc requires size to be a multiple of alignment; we
    // pad the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto results =
        LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
                             ValueRange({allocAlignment, sizeBytes}));

    Value ptr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, ptr, ptr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Computes the alignment for aligned_alloc used to allocate the buffer for
  /// the memory allocation op.
  ///
  /// aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of that alignment.
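  ///
  /// For example, an f32 element yields llvm::PowerOf2Ceil(4) = 4, which is
  /// then raised to the 16-byte kMinAlignedAllocAlignment floor below.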
  int64_t alignedAllocationGetAlignment(memref::AllocOp op,
                                        const DataLayout *defaultLayout) const {
    if (std::optional<uint64_t> alignment = op.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a
    // power of two, we will bump to the next power of two if it isn't.
    unsigned eltSizeBytes = getMemRefEltSizeInBytes(
        getTypeConverter(), op.getType(), op, defaultLayout);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// factor.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
                              const DataLayout *defaultLayout) const {
    uint64_t sizeDivisor =
        getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamicDim(i))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
  using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;

  /// Lowers `memref.alloca` to `llvm.alloca`, which directly yields a pointer
  /// to the element type; no separate allocated pointer or post-allocation
  /// alignment adjustment is needed for stack allocations.
  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value size;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, size,
                                   /*sizeInBytes=*/false);

    // With alloca, one gets a pointer to the element type right away for
    // stack allocations.
    auto elementType =
        typeConverter->convertType(op.getType().getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(op.getType());
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned addrSpace = *maybeAddressSpace;
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
                               op.getAlignment().value_or(0));

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
        strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};

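/// Lowers `memref.alloca_scope` by inlining its body and bracketing it with
/// llvm.intr.stacksave / llvm.intr.stackrestore, so that any llvm.alloca
/// created inside the scope is released again when the scope exits.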
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
    LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
                                     /*indices=*/{});

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
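    // Concretely, this emits an `llvm.intr.assume` whose "align" operand
    // bundle carries the buffer pointer and the alignment constant.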
    Value trueCond =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                           alignmentConst);
    rewriter.replaceOp(op, memref);
    return success();
  }
};

struct DistinctObjectsOpLowering
    : public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
  using ConvertOpToLLVMPattern<
      memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
  explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ValueRange operands = adaptor.getOperands();
    if (operands.size() <= 1) {
      // Fast path.
      rewriter.replaceOp(op, operands);
      return success();
    }

    Location loc = op.getLoc();
    SmallVector<Value> ptrs;
    for (auto [origOperand, newOperand] :
         llvm::zip_equal(op.getOperands(), operands)) {
      auto memrefType = cast<MemRefType>(origOperand.getType());
      MemRefDescriptor memRefDescriptor(newOperand);
      Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                             memrefType);
      ptrs.push_back(ptr);
    }

    auto cond =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
    // Generate separate_storage assumptions for each pair of pointers.
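    // E.g. for three operands this emits assumes for the pairs (0, 1),
    // (0, 2), and (1, 2); the number of assumes grows quadratically.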
    for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
      for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
        Value ptr1 = ptrs[i];
        Value ptr2 = ptrs[j];
        LLVM::AssumeOp::create(rewriter, loc, cond,
                               LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2);
      }
    }

    rewriter.replaceOp(op, operands);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter,
                             SymbolTableCollection *symbolTables = nullptr,
                             PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    FailureOr<LLVM::LLVMFuncOp> freeFunc =
        getFreeFn(rewriter, getTypeConverter(),
                  op->getParentWithTrait<OpTrait::SymbolTable>(), symbolTables);
    if (failed(freeFunc))
      return failure();
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
                                              allocatedPtr);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get pointer to offset field of memref<element_type> descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr =
        LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
                            underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPop with
    // `dimOp.index() + 1` index argument.
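    // The ranked descriptor is laid out as (allocated ptr, aligned ptr,
    // offset, sizes[rank], strides[rank]), so starting from the offset field
    // the size of dimension `i` lives at index `i + 1`.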
    Value idxPlusOne = LLVM::AddOp::create(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
                                        getTypeConverter()->getIndexType(),
                                        offsetPtr, idxPlusOne);
    return LLVM::LoadOp::create(rewriter, loc,
                                getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage if index is constant.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // extract dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use constant for static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///                       |
///          -------|     |
///          |      v     v
///          |  +-----------------------------------+
///          |  | loop(%loaded):                    |
///          |  |   <body contents>                 |
///          |  |   %pair = cmpxchg                 |
///          |  |   %ok = %pair[0]                  |
///          |  |   %new = %pair[1]                 |
///          |  |   cf.cond_br %ok, end, loop(%new) |
///          |  +-----------------------------------+
///          |        |          |
///          ----------          |
///                              v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(
        rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
    Value init = LLVM::LoadOp::create(
        rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
        dataPtr);
    LLVM::BrOp::create(rewriter, loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }

    Value result =
        mapping.lookupOrNull(entryBlock.getTerminator()->getOperand(0));
    if (!result) {
      return atomicOp.emitError("result not defined in region");
    }

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg =
        LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
                                      result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
    Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef<Value>(),
                           loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
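  // For example, memref<4x8xf32> becomes !llvm.array<4 x array<8 x f32>>.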
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter,
                                  SymbolTableCollection *symbolTables = nullptr,
                                  PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::GlobalOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
    bool isExternal = global.isExternal();
    bool isUninitialized = global.isUninitialized();

    Attribute initialValue = nullptr;
    if (!isExternal && !isUninitialized) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");

    // Remove old operation from symbol table.
    SymbolTable *symbolTable = nullptr;
    if (symbolTables) {
      Operation *symbolTableOp =
          global->getParentWithTrait<OpTrait::SymbolTable>();
      symbolTable = &symbolTables->getSymbolTable(symbolTableOp);
      symbolTable->remove(global);
    }

843
844 // Create new operation.
845 auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
846 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
847 initialValue, alignment, *addressSpace);
848
849 // Insert new operation into symbol table.
850 if (symbolTable)
851 symbolTable->insert(newGlobal, rewriter.getInsertionPoint());
852
853 if (!isExternal && isUninitialized) {
854 rewriter.createBlock(&newGlobal.getInitializerRegion());
855 Value undef[] = {
856 LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
857 LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
858 }
859 return success();
860 }
861};
862
863/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
864/// the first element stashed into the descriptor. This reuses
865/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
866struct GetGlobalMemrefOpLowering
867 : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
868 using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;
869
870 /// Buffer "allocation" for memref.get_global op is getting the address of
871 /// the global variable referenced.
872 LogicalResult
873 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
874 ConversionPatternRewriter &rewriter) const override {
875 auto loc = op.getLoc();
876 MemRefType memRefType = op.getType();
877 if (!isConvertibleAndHasIdentityMaps(memRefType))
878 return rewriter.notifyMatchFailure(op, "incompatible memref type");
879
880 // Get actual sizes of the memref as values: static sizes are constant
881 // values and dynamic sizes are passed to 'alloc' as operands. In case of
882 // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    MemRefType type = cast<MemRefType>(op.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    auto gep =
        LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
                            SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc. Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    // Per memref.load spec, the indices must be in-bounds:
    // 0 <= idx < dim_size, and additionally all offsets are non-negative,
    // hence inbounds and nuw are used when lowering to llvm.getelementptr.
    Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
                                         adaptor.getMemref(),
                                         adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
        loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal());
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    // Per memref.store spec, the indices must be in-bounds:
    // 0 <= idx < dim_size, and additionally all offsets are non-negative,
    // hence inbounds and nuw are used when lowering to llvm.getelementptr.
    Value dataPtr =
        getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               op.getAlignment().value_or(0),
                                               false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(
        rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can
    // be used for type erasure. For now it must preserve the underlying
    // element type and requires source and result types to have the same
    // rank. Therefore, perform a sanity check that the underlying structs are
    // the same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      if (typeConverter->convertType(srcType) !=
          typeConverter->convertType(dstType))
        return failure();

    // Unranked to unranked cast is disallowed.
    if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
      return failure();

    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
      return success();
    }

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // - Set the rank in the destination from the memref type.
      // - Allocate space on the stack and copy the src memref descriptor.
      // - Set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                              rewriter.getIndexAttr(rank));
      // poison = PoisonOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
      // d1 = InsertValueOp poison, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked. The operation is assumed to be
      // doing a correct cast. If the destination type mismatches the unranked
      // type, it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }

    return success();
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter,
                                SymbolTableCollection *symbolTables = nullptr,
                                PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::CopyOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(),
                                       elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr =
        LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType,
                            targetBasePtr, targetOffset);
    LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
                           /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                           type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save stack position before promoting descriptors.
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                        rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
      LLVM::StoreOp::create(rewriter, loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive size from llvm.getelementptr which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    FailureOr<LLVM::LLVMFuncOp> copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
        sourcePtr.getType(), symbolTables);
    if (failed(copyFn))
      return failure();
    LLVM::CallOp::create(rewriter, loc, copyFn.value(),
                         ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore stack used for descriptors.
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };
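    // For example, memref<4x8xf32> (identity layout) qualifies, as does
    // memref<4x8xf32, strided<[8, 1], offset: ?>> (contiguous with an
    // arbitrary offset), whereas memref<4x8xf32, strided<[16, 1]>> does not.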

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
      descVals[1] =
          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for new memref descriptor.
      auto result = UnrankedMemRefDescriptor::poison(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
          rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
      Value resultUnderlyingDesc =
          LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
                                 rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = LLVM::AddrSpaceCastOp::create(
          rewriter, loc, resultElemPtrType, allocatedPtr);
      alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
                                                 resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

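      // The underlying descriptor begins with the allocated and aligned
      // pointers, which were already copied (with address space casts) above;
      // skip those two pointer-sized fields and memcpy only the offset,
      // sizes, and strides.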
      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = LLVM::ConstantOp::create(
          rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize =
          LLVM::SubOp::create(rewriter, loc, getIndexType(),
                              resultUnderlyingSize, bytesToSkipConst);
      LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
                             copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In unranked case, the fields are extracted from the underlying
/// ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1398
1399 LogicalResult
1400 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1401 ConversionPatternRewriter &rewriter) const override {
1402 Type srcType = castOp.getSource().getType();
1403
1404 Value descriptor;
1405 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1406 adaptor, &descriptor)))
1407 return failure();
1408 rewriter.replaceOp(castOp, {descriptor});
1409 return success();
1410 }
1411
1412private:
1413 LogicalResult convertSourceMemRefToDescriptor(
1414 ConversionPatternRewriter &rewriter, Type srcType,
1415 memref::ReinterpretCastOp castOp,
1416 memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1417 MemRefType targetMemRefType =
1418 cast<MemRefType>(castOp.getResult().getType());
1419 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1420 typeConverter->convertType(targetMemRefType));
1421 if (!llvmTargetDescriptorTy)
1422 return failure();
1423
1424 // Create descriptor.
1425 Location loc = castOp.getLoc();
1426 auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1427
1428 // Set allocated and aligned pointers.
1429 Value allocatedPtr, alignedPtr;
1430 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1431 castOp.getSource(), adaptor.getSource(),
1432 &allocatedPtr, &alignedPtr);
1433 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1434 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1435
1436 // Set offset.
1437 if (castOp.isDynamicOffset(0))
1438 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1439 else
1440 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1441
1442 // Set sizes and strides.
1443 unsigned dynSizeId = 0;
1444 unsigned dynStrideId = 0;
1445 for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1446 if (castOp.isDynamicSize(i))
1447 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1448 else
1449 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1450
1451 if (castOp.isDynamicStride(i))
1452 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1453 else
1454 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1455 }
1456 *descriptor = desc;
1457 return success();
1458 }
1459};
1460
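/// Lowers memref.reshape. When the shape operand has a static shape, the
/// ranked result descriptor is populated directly; otherwise an unranked
/// descriptor is allocated on the stack and filled by a loop over the
/// runtime rank. Illustrative sketch (editor's example, not from the
/// original source):
///   %dst = memref.reshape %src(%shape)
///        : (memref<?x?xf32>, memref<2xindex>) -> memref<?x?xf32>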
1461struct MemRefReshapeOpLowering
1462 : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1463 using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1464
1465 LogicalResult
1466 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1467 ConversionPatternRewriter &rewriter) const override {
1468 Type srcType = reshapeOp.getSource().getType();
1469
1470 Value descriptor;
1471 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1472 adaptor, &descriptor)))
1473 return failure();
1474 rewriter.replaceOp(reshapeOp, {descriptor});
1475 return success();
1476 }
1477
1478private:
1479 LogicalResult
1480 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1481 Type srcType, memref::ReshapeOp reshapeOp,
1482 memref::ReshapeOp::Adaptor adaptor,
1483 Value *descriptor) const {
1484 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1485 if (shapeMemRefType.hasStaticShape()) {
1486 MemRefType targetMemRefType =
1487 cast<MemRefType>(reshapeOp.getResult().getType());
1488 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1489 typeConverter->convertType(targetMemRefType));
1490 if (!llvmTargetDescriptorTy)
1491 return failure();
1492
1493 // Create descriptor.
1494 Location loc = reshapeOp.getLoc();
1495 auto desc =
1496 MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1497
1498 // Set allocated and aligned pointers.
1499 Value allocatedPtr, alignedPtr;
1500 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1501 reshapeOp.getSource(), adaptor.getSource(),
1502 &allocatedPtr, &alignedPtr);
1503 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1504 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1505
1506 // Extract the offset and strides from the type.
1507 int64_t offset;
1508 SmallVector<int64_t> strides;
1509 if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1510 return rewriter.notifyMatchFailure(
1511 reshapeOp, "failed to get stride and offset exprs");
1512
1513 if (!isStaticStrideOrOffset(offset))
1514 return rewriter.notifyMatchFailure(reshapeOp,
1515 "dynamic offset is unsupported");
1516
1517 desc.setConstantOffset(rewriter, loc, offset);
1518
1519 assert(targetMemRefType.getLayout().isIdentity() &&
1520 "Identity layout map is a precondition of a valid reshape op");
1521
1522 Type indexType = getIndexType();
1523 Value stride = nullptr;
1524 int64_t targetRank = targetMemRefType.getRank();
1525 for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1526 if (ShapedType::isStatic(strides[i])) {
1527 // The stride for this dimension is statically known; materialize it
1528 // as a constant.
1529 stride =
1530 createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
1531 } else if (!stride) {
1532 // `stride` is null only in the first iteration of the loop. However,
1533 // since the target memref has an identity layout, we can safely set
1534 // the innermost stride to 1.
1535 stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
1536 }
1537
1538 Value dimSize;
1539 // If the size of this dimension is dynamic, then load it at runtime
1540 // from the shape operand.
1541 if (!targetMemRefType.isDynamicDim(i)) {
1542 dimSize = createIndexAttrConstant(rewriter, loc, indexType,
1543 targetMemRefType.getDimSize(i));
1544 } else {
1545 Value shapeOp = reshapeOp.getShape();
1546 Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
1547 dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index);
1549 if (dimSize.getType() != indexType)
1550 dimSize = typeConverter->materializeTargetConversion(
1551 rewriter, loc, indexType, dimSize);
1552 assert(dimSize && "Invalid memref element type");
1553 }
1554
1555 desc.setSize(rewriter, loc, i, dimSize);
1556 desc.setStride(rewriter, loc, i, stride);
1557
1558 // Prepare the stride value for the next dimension.
1559 stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
1560 }
1561
1562 *descriptor = desc;
1563 return success();
1564 }
1565
1566 // The shape is a rank-1 memref with unknown length.
1567 Location loc = reshapeOp.getLoc();
1568 MemRefDescriptor shapeDesc(adaptor.getShape());
1569 Value resultRank = shapeDesc.size(rewriter, loc, 0);
1570
1571 // Extract address space and element type.
1572 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1573 unsigned addressSpace =
1574 *getTypeConverter()->getMemRefAddressSpace(targetType);
1575
1576 // Create the unranked memref descriptor that holds the ranked one. The
1577 // inner descriptor is allocated on stack.
1578 auto targetDesc = UnrankedMemRefDescriptor::poison(
1579 rewriter, loc, typeConverter->convertType(targetType));
1580 targetDesc.setRank(rewriter, loc, resultRank);
1581 Value allocationSize = UnrankedMemRefDescriptor::computeSize(
1582 rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
1583 Value underlyingDescPtr = LLVM::AllocaOp::create(
1584 rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
1585 allocationSize);
1586 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1587
1588 // Extract pointers and offset from the source memref.
1589 Value allocatedPtr, alignedPtr, offset;
1590 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1591 reshapeOp.getSource(), adaptor.getSource(),
1592 &allocatedPtr, &alignedPtr, &offset);
1593
1594 // Set pointers and offset.
1595 auto elementPtrType =
1596 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1597
1598 UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1599 elementPtrType, allocatedPtr);
1600 UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1601 underlyingDescPtr, elementPtrType,
1602 alignedPtr);
1603 UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1604 underlyingDescPtr, elementPtrType,
1605 offset);
1606
1607 // Use the offset pointer as base for further addressing. Copy over the new
1608 // shape and compute strides. For this, we create a loop from rank-1 to 0.
1609 Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1610 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1611 Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1612 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1613 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1614 Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
1615 Value resultRankMinusOne =
1616 LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);
1617
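// Editor's sketch of the control flow emitted below (not in the original
// source):
//   init:        llvm.br ^cond(rank - 1, 1)
//   ^cond(i, s): llvm.cond_br (i >= 0), ^body, ^remainder
//   ^body:       size = shape[i]; sizes[i] = size; strides[i] = s
//                llvm.br ^cond(i - 1, s * size)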
1618 Block *initBlock = rewriter.getInsertionBlock();
1619 Type indexType = getTypeConverter()->getIndexType();
1620 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1621
1622 Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1623 {indexType, indexType}, {loc, loc});
1624
1625 // Move the remaining initBlock ops to condBlock.
1626 Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1627 rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1628
1629 rewriter.setInsertionPointToEnd(initBlock);
1630 LLVM::BrOp::create(rewriter, loc,
1631 ValueRange({resultRankMinusOne, oneIndex}), condBlock);
1632 rewriter.setInsertionPointToStart(condBlock);
1633 Value indexArg = condBlock->getArgument(0);
1634 Value strideArg = condBlock->getArgument(1);
1635
1636 Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
1637 Value pred = LLVM::ICmpOp::create(
1638 rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
1639 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1640
1641 Block *bodyBlock =
1642 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1643 rewriter.setInsertionPointToStart(bodyBlock);
1644
1645 // Copy size from shape to descriptor.
1646 auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1647 Value sizeLoadGep = LLVM::GEPOp::create(
1648 rewriter, loc, llvmIndexPtrType,
1649 typeConverter->convertType(shapeMemRefType.getElementType()),
1650 shapeOperandPtr, indexArg);
1651 Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
1652 UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1653 targetSizesBase, indexArg, size);
1654
1655 // Write stride value and compute next one.
1656 UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1657 targetStridesBase, indexArg, strideArg);
1658 Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);
1659
1660 // Decrement loop counter and branch back.
1661 Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
1662 LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}),
1663 condBlock);
1664
1665 Block *remainder =
1666 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1667
1668 // Hook up the cond exit to the remainder.
1669 rewriter.setInsertionPointToEnd(condBlock);
1670 LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(),
1671 remainder, ValueRange());
1672
1673 // Reset position to beginning of new remainder block.
1674 rewriter.setInsertionPointToStart(remainder);
1675
1676 *descriptor = targetDesc;
1677 return success();
1678 }
1679};
1680
1681/// ReassociatingReshapeOps must be expanded before we reach this stage.
1682/// This pattern only reports that requirement as a match failure.
1683template <typename ReshapeOp>
1684class ReassociatingReshapeOpConversion
1685 : public ConvertOpToLLVMPattern<ReshapeOp> {
1686public:
1687 using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1688 using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1689
1690 LogicalResult
1691 matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1692 ConversionPatternRewriter &rewriter) const override {
1693 return rewriter.notifyMatchFailure(
1694 reshapeOp,
1695 "reassociation operations should have been expanded beforehand");
1696 }
1697};
1698
1699/// Subviews must be expanded before we reach this stage.
1700/// This pattern only reports that requirement as a match failure.
1701struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1702 using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1703
1704 LogicalResult
1705 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1706 ConversionPatternRewriter &rewriter) const override {
1707 return rewriter.notifyMatchFailure(
1708 subViewOp, "subview operations should have been expanded beforehand");
1709 }
1710};
1711
1712/// Conversion pattern that transforms a transpose op into:
1713///   1. A `poison` memref descriptor of the result type.
1714///   2. Updates to the descriptor copying the allocated and aligned
1715///      pointers and the offset from the source descriptor, and
1716///      permuting its sizes and strides (when the permutation is the
1717///      identity, the source descriptor is reused directly).
1718/// The transpose op is replaced by the new descriptor.
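/// Illustrative example (editor's sketch, not from the original source):
///   %t = memref.transpose %m (i, j) -> (j, i)
///       : memref<?x?xf32> to memref<?x?xf32, strided<[1, ?]>>
/// Only the descriptor is rewritten; no data is moved.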
1719class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1720public:
1721 using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1722
1723 LogicalResult
1724 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1725 ConversionPatternRewriter &rewriter) const override {
1726 auto loc = transposeOp.getLoc();
1727 MemRefDescriptor viewMemRef(adaptor.getIn());
1728
1729 // No permutation, early exit.
1730 if (transposeOp.getPermutation().isIdentity())
1731 return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1732
1733 auto targetMemRef = MemRefDescriptor::poison(
1734 rewriter, loc,
1735 typeConverter->convertType(transposeOp.getIn().getType()));
1736
1737 // Copy the base and aligned pointers from the old descriptor to the new
1738 // one.
1739 targetMemRef.setAllocatedPtr(rewriter, loc,
1740 viewMemRef.allocatedPtr(rewriter, loc));
1741 targetMemRef.setAlignedPtr(rewriter, loc,
1742 viewMemRef.alignedPtr(rewriter, loc));
1743
1744 // Copy the offset pointer from the old descriptor to the new one.
1745 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1746
1747 // Iterate over the dimensions and apply size/stride permutation:
1748 // When enumerating the results of the permutation map, the enumeration
1749 // index is the index into the target dimensions and the DimExpr points to
1750 // the dimension of the source memref.
1751 for (const auto &en :
1752 llvm::enumerate(transposeOp.getPermutation().getResults())) {
1753 int targetPos = en.index();
1754 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1755 targetMemRef.setSize(rewriter, loc, targetPos,
1756 viewMemRef.size(rewriter, loc, sourcePos));
1757 targetMemRef.setStride(rewriter, loc, targetPos,
1758 viewMemRef.stride(rewriter, loc, sourcePos));
1759 }
1760
1761 rewriter.replaceOp(transposeOp, {targetMemRef});
1762 return success();
1763 }
1764};
1765
1766/// Conversion pattern that transforms a view op into:
1767///   1. An `llvm.mlir.poison` operation to create a memref descriptor.
1768///   2. Updates to the descriptor to introduce the data ptr, offset, size
1769///      and stride.
1770/// The view op is replaced by the descriptor.
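/// Illustrative example (editor's sketch, not from the original source):
///   %v = memref.view %buf[%byte_shift][] : memref<2048xi8> to memref<64x4xf32>
/// The byte shift is applied to the aligned pointer with a GEP.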
1771struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1772 using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1773
1774 // Build and return the value for the idx^th shape dimension: either the
1775 // constant shape dimension or the matching operand among the dynamic sizes.
1776 Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1777 ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1778 Type indexType) const {
1779 assert(idx < shape.size());
1780 if (ShapedType::isStatic(shape[idx]))
1781 return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
1782 // Count the number of dynamic dims in the range [0, idx).
1783 unsigned nDynamic =
1784 llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1785 return dynamicSizes[nDynamic];
1786 }
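// Example for getSize (editor's note): with shape = [4, ?, 8, ?] and idx = 3,
// one dynamic dim precedes idx, so dynamicSizes[1] is returned.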
1787
1788 // Build and return the idx^th stride, either by returning the constant stride
1789 // or by computing the dynamic stride from the current `runningStride` and
1790 // `nextSize`. The caller should keep a running stride and update it with the
1791 // result returned by this function.
1792 Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1793 ArrayRef<int64_t> strides, Value nextSize,
1794 Value runningStride, unsigned idx, Type indexType) const {
1795 assert(idx < strides.size());
1796 if (ShapedType::isStatic(strides[idx]))
1797 return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
1798 if (nextSize)
1799 return runningStride
1800 ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
1801 : nextSize;
1802 assert(!runningStride);
1803 return createIndexAttrConstant(rewriter, loc, indexType, 1);
1804 }
1805
1806 LogicalResult
1807 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1808 ConversionPatternRewriter &rewriter) const override {
1809 auto loc = viewOp.getLoc();
1810
1811 auto viewMemRefType = viewOp.getType();
1812 auto targetElementTy =
1813 typeConverter->convertType(viewMemRefType.getElementType());
1814 auto targetDescTy = typeConverter->convertType(viewMemRefType);
1815 if (!targetDescTy || !targetElementTy ||
1816 !LLVM::isCompatibleType(targetElementTy) ||
1817 !LLVM::isCompatibleType(targetDescTy))
1818 return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1819 failure();
1820
1821 int64_t offset;
1822 SmallVector<int64_t, 4> strides;
1823 auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1824 if (failed(successStrides))
1825 return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1826 assert(offset == 0 && "expected offset to be 0");
1827
1828 // Target memref must be contiguous in memory (innermost stride is 1), or
1829 // empty (special case when at least one of the memref dimensions is 0).
1830 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1831 return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1832 failure();
1833
1834 // Create the descriptor.
1835 MemRefDescriptor sourceMemRef(adaptor.getSource());
1836 auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);
1837
1838 // Field 1: Copy the allocated pointer, used for malloc/free.
1839 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1840 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1841 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1842
1843 // Field 2: Copy the actual aligned pointer to payload.
1844 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1845 alignedPtr = LLVM::GEPOp::create(
1846 rewriter, loc, alignedPtr.getType(),
1847 typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1848 adaptor.getByteShift());
1849
1850 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1851
1852 Type indexType = getIndexType();
1853 // Field 3: The offset in the resulting type must be 0. This is
1854 // because of the type change: an offset on srcType* may not be
1855 // expressible as an offset on dstType*.
1856 targetMemRef.setOffset(
1857 rewriter, loc,
1858 createIndexAttrConstant(rewriter, loc, indexType, offset));
1859
1860 // Early exit for 0-D corner case.
1861 if (viewMemRefType.getRank() == 0)
1862 return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1863
1864 // Fields 4 and 5: Update sizes and strides.
1865 Value stride = nullptr, nextSize = nullptr;
1866 for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1867 // Update size.
1868 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1869 adaptor.getSizes(), i, indexType);
1870 targetMemRef.setSize(rewriter, loc, i, size);
1871 // Update stride.
1872 stride =
1873 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1874 targetMemRef.setStride(rewriter, loc, i, stride);
1875 nextSize = size;
1876 }
1877
1878 rewriter.replaceOp(viewOp, {targetMemRef});
1879 return success();
1880 }
1881};
1882
1883//===----------------------------------------------------------------------===//
1884// AtomicRMWOpLowering
1885//===----------------------------------------------------------------------===//
1886
1887/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1888/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1889static std::optional<LLVM::AtomicBinOp>
1890matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1891 switch (atomicOp.getKind()) {
1892 case arith::AtomicRMWKind::addf:
1893 return LLVM::AtomicBinOp::fadd;
1894 case arith::AtomicRMWKind::addi:
1895 return LLVM::AtomicBinOp::add;
1896 case arith::AtomicRMWKind::assign:
1897 return LLVM::AtomicBinOp::xchg;
1898 case arith::AtomicRMWKind::maximumf:
1899 // TODO: remove this by end of 2025.
1900 LDBG() << "the lowering of memref.atomicrmw maximumf changed "
1901 "from fmax to fmaximum, expect more NaNs";
1902 return LLVM::AtomicBinOp::fmaximum;
1903 case arith::AtomicRMWKind::maxnumf:
1904 return LLVM::AtomicBinOp::fmax;
1905 case arith::AtomicRMWKind::maxs:
1906 return LLVM::AtomicBinOp::max;
1907 case arith::AtomicRMWKind::maxu:
1908 return LLVM::AtomicBinOp::umax;
1909 case arith::AtomicRMWKind::minimumf:
1910 // TODO: remove this by end of 2025.
1911 LDBG() << "the lowering of memref.atomicrmw minimumf changed "
1912 "from fmin to fminimum, expect more NaNs";
1913 return LLVM::AtomicBinOp::fminimum;
1914 case arith::AtomicRMWKind::minnumf:
1915 return LLVM::AtomicBinOp::fmin;
1916 case arith::AtomicRMWKind::mins:
1917 return LLVM::AtomicBinOp::min;
1918 case arith::AtomicRMWKind::minu:
1919 return LLVM::AtomicBinOp::umin;
1920 case arith::AtomicRMWKind::ori:
1921 return LLVM::AtomicBinOp::_or;
1922 case arith::AtomicRMWKind::xori:
1923 return LLVM::AtomicBinOp::_xor;
1924 case arith::AtomicRMWKind::andi:
1925 return LLVM::AtomicBinOp::_and;
1926 default:
1927 return std::nullopt;
1928 }
1929 llvm_unreachable("Invalid AtomicRMWKind");
1930}
1931
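/// Illustrative example (editor's sketch, not from the original source):
///   %old = memref.atomic_rmw addf %v, %m[%i] : (f32, memref<8xf32>) -> f32
/// lowers to a single llvm.atomicrmw fadd on the strided element pointer.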
1932struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1933 using Base::Base;
1934
1935 LogicalResult
1936 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1937 ConversionPatternRewriter &rewriter) const override {
1938 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1939 if (!maybeKind)
1940 return failure();
1941 auto memRefType = atomicOp.getMemRefType();
1942 SmallVector<int64_t> strides;
1943 int64_t offset;
1944 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1945 return failure();
1946 auto dataPtr =
1947 getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
1948 adaptor.getMemref(), adaptor.getIndices());
1949 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1950 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1951 LLVM::AtomicOrdering::acq_rel);
1952 return success();
1953 }
1954};
1955
1956/// Lowers memref.extract_aligned_pointer_as_index to an llvm.ptrtoint of the aligned pointer.
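/// Illustrative example (editor's sketch, not from the original source):
///   %p = memref.extract_aligned_pointer_as_index %m : memref<4x4xf32> -> index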
1957class ConvertExtractAlignedPointerAsIndex
1958 : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1959public:
1960 using ConvertOpToLLVMPattern<
1961 memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1962
1963 LogicalResult
1964 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1965 OpAdaptor adaptor,
1966 ConversionPatternRewriter &rewriter) const override {
1967 BaseMemRefType sourceTy = extractOp.getSource().getType();
1968
1969 Value alignedPtr;
1970 if (sourceTy.hasRank()) {
1971 MemRefDescriptor desc(adaptor.getSource());
1972 alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1973 } else {
1974 auto elementPtrTy = LLVM::LLVMPointerType::get(
1975 rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1976
1977 UnrankedMemRefDescriptor desc(adaptor.getSource());
1978 Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1979
1980 alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1981 rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1982 elementPtrTy);
1983 }
1984
1985 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1986 extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1987 return success();
1988 }
1989};
1990
1991/// Materialize the MemRef descriptor represented by the results of
1992/// ExtractStridedMetadataOp.
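/// Illustrative example (editor's sketch, not from the original source):
///   %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
///       : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
/// Every result is extracted from a field of the source descriptor.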
1993class ExtractStridedMetadataOpLowering
1994 : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1995public:
1996 using ConvertOpToLLVMPattern<
1997 memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1998
1999 LogicalResult
2000 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
2001 OpAdaptor adaptor,
2002 ConversionPatternRewriter &rewriter) const override {
2003
2004 if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
2005 return failure();
2006
2007 // Create the descriptor.
2008 MemRefDescriptor sourceMemRef(adaptor.getSource());
2009 Location loc = extractStridedMetadataOp.getLoc();
2010 Value source = extractStridedMetadataOp.getSource();
2011
2012 auto sourceMemRefType = cast<MemRefType>(source.getType());
2013 int64_t rank = sourceMemRefType.getRank();
2014 SmallVector<Value> results;
2015 results.reserve(2 + rank * 2);
2016
2017 // Base buffer.
2018 Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
2019 Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
2020 MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
2021 rewriter, loc, *getTypeConverter(),
2022 cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
2023 baseBuffer, alignedBuffer);
2024 results.push_back((Value)dstMemRef);
2025
2026 // Offset.
2027 results.push_back(sourceMemRef.offset(rewriter, loc));
2028
2029 // Sizes.
2030 for (unsigned i = 0; i < rank; ++i)
2031 results.push_back(sourceMemRef.size(rewriter, loc, i));
2032 // Strides.
2033 for (unsigned i = 0; i < rank; ++i)
2034 results.push_back(sourceMemRef.stride(rewriter, loc, i));
2035
2036 rewriter.replaceOp(extractStridedMetadataOp, results);
2037 return success();
2038 }
2039};
2040
2041} // namespace
2042
2043void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
2044 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2045 SymbolTableCollection *symbolTables) {
2046 // clang-format off
2047 patterns.add<
2048 AllocaOpLowering,
2049 AllocaScopeOpLowering,
2050 AssumeAlignmentOpLowering,
2051 AtomicRMWOpLowering,
2052 ConvertExtractAlignedPointerAsIndex,
2053 DimOpLowering,
2054 DistinctObjectsOpLowering,
2055 ExtractStridedMetadataOpLowering,
2056 GenericAtomicRMWOpLowering,
2057 GetGlobalMemrefOpLowering,
2058 LoadOpLowering,
2059 MemRefCastOpLowering,
2060 MemRefReinterpretCastOpLowering,
2061 MemRefReshapeOpLowering,
2062 MemorySpaceCastOpLowering,
2063 PrefetchOpLowering,
2064 RankOpLowering,
2065 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2066 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2067 StoreOpLowering,
2068 SubViewOpLowering,
2069 TransposeOpLowering,
2070 ViewOpLowering>(converter);
2071 // clang-format on
2072 patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
2073 symbolTables);
2074 auto allocLowering = converter.getOptions().allocLowering;
2075 if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2076 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
2077 symbolTables);
2078 else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
2079 patterns.add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
2080}
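// Typical driver usage of the populate function above (editor's sketch,
// mirroring the pass below):
//   LLVMTypeConverter converter(&ctx);
//   RewritePatternSet patterns(&ctx);
//   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(ctx);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();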
2081
2082namespace {
2083struct FinalizeMemRefToLLVMConversionPass
2084 : public impl::FinalizeMemRefToLLVMConversionPassBase<
2085 FinalizeMemRefToLLVMConversionPass> {
2086 using FinalizeMemRefToLLVMConversionPassBase::
2087 FinalizeMemRefToLLVMConversionPassBase;
2088
2089 void runOnOperation() override {
2090 Operation *op = getOperation();
2091 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2092 LowerToLLVMOptions options(&getContext(),
2093 dataLayoutAnalysis.getAtOrAbove(op));
2094 options.allocLowering =
2095 (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2096 : LowerToLLVMOptions::AllocLowering::Malloc);
2097
2098 options.useGenericFunctions = useGenericFunctions;
2099
2100 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2101 options.overrideIndexBitwidth(indexBitwidth);
2102
2103 LLVMTypeConverter typeConverter(&getContext(), options,
2104 &dataLayoutAnalysis);
2105 RewritePatternSet patterns(&getContext());
2106 SymbolTableCollection symbolTables;
2107 populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns,
2108 &symbolTables);
2109 LLVMConversionTarget target(getContext());
2110 target.addLegalOp<func::FuncOp>();
2111 if (failed(applyPartialConversion(op, target, std::move(patterns))))
2112 signalPassFailure();
2113 }
2114};
2115
2116/// Implement the interface to convert MemRef to LLVM.
2117struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
2118 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
2119 void loadDependentDialects(MLIRContext *context) const final {
2120 context->loadDialect<LLVM::LLVMDialect>();
2121 }
2122
2123 /// Hook for derived dialect interface to provide conversion patterns
2124 /// and mark dialect legal for the conversion target.
2125 void populateConvertToLLVMConversionPatterns(
2126 ConversionTarget &target, LLVMTypeConverter &typeConverter,
2127 RewritePatternSet &patterns) const final {
2128 populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
2129 }
2130};
2131
2132} // namespace
2133
2134void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
2135 registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
2136 dialect->addInterfaces<MemRefToLLVMDialectInterface>();
2137 });
2138}