MLIR 22.0.0git
MemRefToLLVM.cpp
1//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10
11#include "mlir/Analysis/DataLayoutAnalysis.h"
12#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14#include "mlir/Conversion/LLVMCommon/Pattern.h"
15#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
23#include "mlir/IR/AffineMap.h"
24#include "mlir/IR/BuiltinTypes.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/Pass/Pass.h"
27#include "llvm/Support/DebugLog.h"
28#include "llvm/Support/MathExtras.h"
29
30#include <optional>
31
32#define DEBUG_TYPE "memref-to-llvm"
33
34namespace mlir {
35#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
36#include "mlir/Conversion/Passes.h.inc"
37} // namespace mlir
38
39using namespace mlir;
40
41static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
42 LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
43
44namespace {
45
46static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
47 return ShapedType::isStatic(strideOrOffset);
48}
49
50static FailureOr<LLVM::LLVMFuncOp>
51getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
52 Operation *module, SymbolTableCollection *symbolTables) {
53 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
54
55 if (useGenericFn)
56 return LLVM::lookupOrCreateGenericFreeFn(b, module, symbolTables);
57
58 return LLVM::lookupOrCreateFreeFn(b, module, symbolTables);
59}
60
61static FailureOr<LLVM::LLVMFuncOp>
62getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
63 Operation *module, Type indexType,
64 SymbolTableCollection *symbolTables) {
65 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
66 if (useGenericFn)
67 return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType,
68 symbolTables);
69
70 return LLVM::lookupOrCreateMallocFn(b, module, indexType, symbolTables);
71}
72
73static FailureOr<LLVM::LLVMFuncOp>
74getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
75 Operation *module, Type indexType,
76 SymbolTableCollection *symbolTables) {
77 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
78
79 if (useGenericFn)
80 return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType,
81 symbolTables);
82
83 return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType, symbolTables);
84}
85
86/// Computes the aligned value for 'input' as follows:
87/// bumped = input + alignment - 1
88/// aligned = bumped - bumped % alignment
89static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
90 Value input, Value alignment) {
91 Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(),
92 rewriter.getIndexAttr(1));
93 Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
94 Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
95 Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
96 return LLVM::SubOp::create(rewriter, loc, bumped, mod);
97}
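
// Worked example: with input = 42 and alignment = 16 the sequence yields
//   bumped  = 42 + (16 - 1) = 57
//   aligned = 57 - (57 % 16) = 57 - 9 = 48,
// i.e. the smallest multiple of 16 that is >= 42. The formula holds for any
// positive alignment, though callers here only pass powers of two.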
98
99/// Computes the byte size for the MemRef element type.
100static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
101 MemRefType memRefType, Operation *op,
102 const DataLayout *defaultLayout) {
103 const DataLayout *layout = defaultLayout;
104 if (const DataLayoutAnalysis *analysis =
105 typeConverter->getDataLayoutAnalysis()) {
106 layout = &analysis->getAbove(op);
107 }
108 Type elementType = memRefType.getElementType();
109 if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
110 return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
111 if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
112 return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
113 *layout);
114 return layout->getTypeSize(elementType);
115}
116
117static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
118 Location loc, Value allocatedPtr,
119 MemRefType memRefType, Type elementPtrType,
120 const LLVMTypeConverter &typeConverter) {
121 auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
122 FailureOr<unsigned> maybeMemrefAddrSpace =
123 typeConverter.getMemRefAddressSpace(memRefType);
124 assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
125 unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
126 if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
127 allocatedPtr = LLVM::AddrSpaceCastOp::create(
128 rewriter, loc,
129 LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
130 allocatedPtr);
131 return allocatedPtr;
132}
133
134class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
135 SymbolTableCollection *symbolTables = nullptr;
136
137public:
138 explicit AllocOpLowering(const LLVMTypeConverter &typeConverter,
139 SymbolTableCollection *symbolTables = nullptr,
140 PatternBenefit benefit = 1)
141 : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
142 symbolTables(symbolTables) {}
143
144 LogicalResult
145 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
146 ConversionPatternRewriter &rewriter) const override {
147 auto loc = op.getLoc();
148 MemRefType memRefType = op.getType();
149 if (!isConvertibleAndHasIdentityMaps(memRefType))
150 return rewriter.notifyMatchFailure(op, "incompatible memref type");
151
152 // Get or insert alloc function into the module.
153 FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
154 getNotalignedAllocFn(rewriter, getTypeConverter(),
155 op->getParentWithTrait<OpTrait::SymbolTable>(),
156 getIndexType(), symbolTables);
157 if (failed(allocFuncOp))
158 return failure();
159
160 // Get actual sizes of the memref as values: static sizes are constant
161 // values and dynamic sizes are passed to 'alloc' as operands. In case of
162 // zero-dimensional memref, assume a scalar (size 1).
163 SmallVector<Value, 4> sizes;
164 SmallVector<Value, 4> strides;
165 Value sizeBytes;
166
167 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
168 rewriter, sizes, strides, sizeBytes, /*sizeInBytes=*/true);
169
170 Value alignment = getAlignment(rewriter, loc, op);
171 if (alignment) {
172 // Adjust the allocation size to consider alignment.
173 sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
174 }
175
176 // Allocate the underlying buffer.
177 Type elementPtrType = this->getElementPtrType(memRefType);
178 assert(elementPtrType && "could not compute element ptr type");
179 auto results =
180 LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);
181
182 Value allocatedPtr =
183 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
184 elementPtrType, *getTypeConverter());
185 Value alignedPtr = allocatedPtr;
186 if (alignment) {
187 // Compute the aligned pointer.
188 Value allocatedInt =
189 LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
190 Value alignmentInt =
191 createAligned(rewriter, loc, allocatedInt, alignment);
192 alignedPtr =
193 LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
194 }
195
196 // Create the MemRef descriptor.
197 auto memRefDescriptor = this->createMemRefDescriptor(
198 loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
199
200 // Return the final value of the descriptor.
201 rewriter.replaceOp(op, {memRefDescriptor});
202 return success();
203 }
204
205 /// Computes the alignment for the given memory allocation op.
206 template <typename OpType>
207 Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
208 OpType op) const {
209 MemRefType memRefType = op.getType();
210 Value alignment;
211 if (auto alignmentAttr = op.getAlignment()) {
212 Type indexType = getIndexType();
213 alignment =
214 createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
215 } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
216 // In the case where no alignment is specified, we may want to override
217 // `malloc's` behavior. `malloc` typically aligns at the size of the
218 // biggest scalar on a target HW. For non-scalars, use the natural
219 // alignment of the LLVM type given by the LLVM DataLayout.
220 alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
221 }
222 return alignment;
223 }
224};
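
// For illustration, ignoring the exact size computation (which uses the
// usual null-pointer GEP sizeof trick), a dynamic allocation such as
//
//   %m = memref.alloc(%n) : memref<?xf32>
//
// lowers roughly to the following with 64-bit indices and the default
// (non-generic) runtime functions:
//
//   %bytes = ...                                   // %n * sizeof(f32)
//   %raw   = llvm.call @malloc(%bytes) : (i64) -> !llvm.ptr
//   ...llvm.insertvalue ops building the (allocated, aligned, offset,
//   sizes, strides) descriptor struct...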
225
226class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
227 SymbolTableCollection *symbolTables = nullptr;
228
229public:
230 explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter,
231 SymbolTableCollection *symbolTables = nullptr,
232 PatternBenefit benefit = 1)
233 : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
234 symbolTables(symbolTables) {}
235
236 LogicalResult
237 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
238 ConversionPatternRewriter &rewriter) const override {
239 auto loc = op.getLoc();
240 MemRefType memRefType = op.getType();
241 if (!isConvertibleAndHasIdentityMaps(memRefType))
242 return rewriter.notifyMatchFailure(op, "incompatible memref type");
243
244 // Get or insert alloc function into module.
245 FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
246 getAlignedAllocFn(rewriter, getTypeConverter(),
247 op->getParentWithTrait<OpTrait::SymbolTable>(),
248 getIndexType(), symbolTables);
249 if (failed(allocFuncOp))
250 return failure();
251
252 // Get actual sizes of the memref as values: static sizes are constant
253 // values and dynamic sizes are passed to 'alloc' as operands. In case of
254 // zero-dimensional memref, assume a scalar (size 1).
255 SmallVector<Value, 4> sizes;
256 SmallVector<Value, 4> strides;
257 Value sizeBytes;
258
259 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
260 rewriter, sizes, strides, sizeBytes, /*sizeInBytes=*/true);
261
262 int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
263
264 Value allocAlignment =
265 createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
266
267 // Function aligned_alloc requires size to be a multiple of alignment; we
268 // pad the size to the next multiple if necessary.
269 if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
270 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
271
272 Type elementPtrType = this->getElementPtrType(memRefType);
273 auto results =
274 LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
275 ValueRange({allocAlignment, sizeBytes}));
276
277 Value ptr =
278 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
279 elementPtrType, *getTypeConverter());
280
281 // Create the MemRef descriptor.
282 auto memRefDescriptor = this->createMemRefDescriptor(
283 loc, memRefType, ptr, ptr, sizes, strides, rewriter);
284
285 // Return the final value of the descriptor.
286 rewriter.replaceOp(op, {memRefDescriptor});
287 return success();
288 }
289
290 /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
291 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
292
293 /// Computes the alignment for aligned_alloc used to allocate the buffer for
294 /// the memory allocation op.
295 ///
296/// aligned_alloc requires the alignment to be a power of two and the
297/// allocation size to be a multiple of that alignment.
298 int64_t alignedAllocationGetAlignment(memref::AllocOp op,
299 const DataLayout *defaultLayout) const {
300 if (std::optional<uint64_t> alignment = op.getAlignment())
301 return *alignment;
302
303 // Whenever we don't have alignment set, we will use an alignment
304 // consistent with the element type; since the alignment has to be a power
305 // of two, we bump the element size to the next power of two if it isn't.
306 unsigned eltSizeBytes = getMemRefEltSizeInBytes(
307 getTypeConverter(), op.getType(), op, defaultLayout);
308 return std::max(kMinAlignedAllocAlignment,
309 llvm::PowerOf2Ceil(eltSizeBytes));
310 }
311
312 /// Returns true if the memref size in bytes is known to be a multiple of
313 /// factor.
314 bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
315 const DataLayout *defaultLayout) const {
316 uint64_t sizeDivisor =
317 getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
318 for (unsigned i = 0, e = type.getRank(); i < e; i++) {
319 if (type.isDynamicDim(i))
320 continue;
321 sizeDivisor = sizeDivisor * type.getDimSize(i);
322 }
323 return sizeDivisor % factor == 0;
324 }
325
326private:
327 /// Default layout to use in absence of the corresponding analysis.
328 DataLayout defaultLayout;
329};
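
// For illustration, `memref.alloc() {alignment = 32} : memref<10xf32>` needs
// 40 bytes, which is not a multiple of 32, so the size is first padded with
// the add/urem sequence from createAligned (here folding to 64) and the
// buffer is then obtained roughly as:
//
//   %p = llvm.call @aligned_alloc(%c32, %c64) : (i64, i64) -> !llvm.ptr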
330
331struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
332 using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;
333
334 /// Lowers memref.alloca to llvm.alloca. The buffer lives on the stack, so
335 /// nothing is ever freed; llvm.alloca honors the requested alignment
336 /// directly, hence the allocated and aligned pointers coincide.
337 LogicalResult
338 matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter) const override {
340 auto loc = op.getLoc();
341 MemRefType memRefType = op.getType();
342 if (!isConvertibleAndHasIdentityMaps(memRefType))
343 return rewriter.notifyMatchFailure(op, "incompatible memref type");
344
345 // Get actual sizes of the memref as values: static sizes are constant
346 // values and dynamic sizes are passed to 'alloc' as operands. In case of
347 // zero-dimensional memref, assume a scalar (size 1).
348 SmallVector<Value, 4> sizes;
349 SmallVector<Value, 4> strides;
350 Value size;
351
352 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
353 rewriter, sizes, strides, size, /*sizeInBytes=*/false);
354
355 // With alloca, one gets a pointer to the element type right away; this
356 // is used for stack allocations, so no separate aligned pointer is needed.
357 auto elementType =
358 typeConverter->convertType(op.getType().getElementType());
359 FailureOr<unsigned> maybeAddressSpace =
360 getTypeConverter()->getMemRefAddressSpace(op.getType());
361 assert(succeeded(maybeAddressSpace) && "unsupported address space");
362 unsigned addrSpace = *maybeAddressSpace;
363 auto elementPtrType =
364 LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
365
366 auto allocatedElementPtr =
367 LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
368 op.getAlignment().value_or(0));
369
370 // Create the MemRef descriptor.
371 auto memRefDescriptor = this->createMemRefDescriptor(
372 loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
373 strides, rewriter);
374
375 // Return the final value of the descriptor.
376 rewriter.replaceOp(op, {memRefDescriptor});
377 return success();
378 }
379};
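
// For illustration, `%m = memref.alloca() : memref<4xf32>` becomes roughly:
//
//   %c4 = llvm.mlir.constant(4 : index) : i64
//   %p  = llvm.alloca %c4 x f32 : (i64) -> !llvm.ptr
//
// followed by the usual descriptor construction. Note that the size operand
// counts elements rather than bytes, which is why getMemRefDescriptorSizes
// is invoked with sizeInBytes = false above.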
380
381struct AllocaScopeOpLowering
382 : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
383 using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
384
385 LogicalResult
386 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
387 ConversionPatternRewriter &rewriter) const override {
388 OpBuilder::InsertionGuard guard(rewriter);
389 Location loc = allocaScopeOp.getLoc();
390
391 // Split the current block before the AllocaScopeOp to create the inlining
392 // point.
393 auto *currentBlock = rewriter.getInsertionBlock();
394 auto *remainingOpsBlock =
395 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
396 Block *continueBlock;
397 if (allocaScopeOp.getNumResults() == 0) {
398 continueBlock = remainingOpsBlock;
399 } else {
400 continueBlock = rewriter.createBlock(
401 remainingOpsBlock, allocaScopeOp.getResultTypes(),
402 SmallVector<Location>(allocaScopeOp->getNumResults(),
403 allocaScopeOp.getLoc()));
404 LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock);
405 }
406
407 // Inline body region.
408 Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
409 Block *afterBody = &allocaScopeOp.getBodyRegion().back();
410 rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
411
412 // Save stack and then branch into the body of the region.
413 rewriter.setInsertionPointToEnd(currentBlock);
414 auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
415 LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody);
416
417 // Replace the alloca_scope return with a branch that jumps out of the body.
418 // Stack restore before leaving the body region.
419 rewriter.setInsertionPointToEnd(afterBody);
420 auto returnOp =
421 cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
422 auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
423 returnOp, returnOp.getResults(), continueBlock);
424
425 // Insert stack restore before jumping out of the body of the region.
426 rewriter.setInsertionPoint(branchOp);
427 LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
428
429 // Replace the op with values return from the body region.
430 rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
431
432 return success();
433 }
434};
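
// The net effect is to bracket the inlined body with the stack intrinsics,
// roughly:
//
//   %sp = llvm.intr.stacksave : !llvm.ptr
//   llvm.br ^body
// ^body:
//   ...inlined alloca_scope region...
//   llvm.intr.stackrestore %sp : !llvm.ptr
//   llvm.br ^continue(...)
//
// so any llvm.alloca issued inside the scope is reclaimed on exit.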
435
436struct AssumeAlignmentOpLowering
437 : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
438 using ConvertOpToLLVMPattern<
439 memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
440 explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
441 : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}
442
443 LogicalResult
444 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
445 ConversionPatternRewriter &rewriter) const override {
446 Value memref = adaptor.getMemref();
447 unsigned alignment = op.getAlignment();
448 auto loc = op.getLoc();
449
450 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
451 Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
452 /*indices=*/{});
453
454 // Emit llvm.assume(true) ["align"(memref, alignment)].
455 // This is more direct than ptrtoint-based checks, is explicitly supported,
456 // and works with non-integral address spaces.
457 Value trueCond =
458 LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
459 Value alignmentConst =
460 createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
461 LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr,
462 alignmentConst);
463 rewriter.replaceOp(op, memref);
464 return success();
465 }
466};
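
// For illustration, `memref.assume_alignment %m, 16 : memref<?xf32>` emits
// roughly the following (modulo the exact assembly syntax of the operand
// bundle):
//
//   %true = llvm.mlir.constant(true) : i1
//   %c16  = llvm.mlir.constant(16 : index) : i64
//   llvm.intr.assume %true ["align"(%ptr, %c16 : !llvm.ptr, i64)] : i1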
467
468struct DistinctObjectsOpLowering
469 : public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
470 using ConvertOpToLLVMPattern<
471 memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
472 explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter)
473 : ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {}
474
475 LogicalResult
476 matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
477 ConversionPatternRewriter &rewriter) const override {
478 ValueRange operands = adaptor.getOperands();
479 if (operands.size() <= 1) {
480 // Fast path.
481 rewriter.replaceOp(op, operands);
482 return success();
483 }
484
485 Location loc = op.getLoc();
486 SmallVector<Value> ptrs;
487 for (auto [origOperand, newOperand] :
488 llvm::zip_equal(op.getOperands(), operands)) {
489 auto memrefType = cast<MemRefType>(origOperand.getType());
490 MemRefDescriptor memRefDescriptor(newOperand);
491 Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
492 memrefType);
493 ptrs.push_back(ptr);
494 }
495
496 auto cond =
497 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
498 // Generate separate_storage assumptions for each pair of pointers.
499 for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
500 for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
501 Value ptr1 = ptrs[i];
502 Value ptr2 = ptrs[j];
503 LLVM::AssumeOp::create(rewriter, loc, cond,
504 LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2);
505 }
506 }
507
508 rewriter.replaceOp(op, operands);
509 return success();
510 }
511};
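
// Note that this emits one llvm.intr.assume per unordered pair of pointers,
// i.e. n * (n - 1) / 2 assumptions for n operands; for three memrefs these
// are the pairs (0,1), (0,2) and (1,2), so the IR grows quadratically with
// the operand count.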
512
513// A `dealloc` is converted into a call to `free` on the underlying data buffer.
514// The memref descriptor being an SSA value, there is no need to clean it up
515// in any way.
516class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
517 SymbolTableCollection *symbolTables = nullptr;
518
519public:
520 explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter,
521 SymbolTableCollection *symbolTables = nullptr,
522 PatternBenefit benefit = 1)
523 : ConvertOpToLLVMPattern<memref::DeallocOp>(typeConverter, benefit),
524 symbolTables(symbolTables) {}
525
526 LogicalResult
527 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
528 ConversionPatternRewriter &rewriter) const override {
529 // Insert the `free` declaration if it is not already present.
530 FailureOr<LLVM::LLVMFuncOp> freeFunc =
531 getFreeFn(rewriter, getTypeConverter(),
532 op->getParentWithTrait<OpTrait::SymbolTable>(), symbolTables);
533 if (failed(freeFunc))
534 return failure();
535 Value allocatedPtr;
536 if (auto unrankedTy =
537 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
538 auto elementPtrTy = LLVM::LLVMPointerType::get(
539 rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
540 allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
541 rewriter, op.getLoc(),
542 UnrankedMemRefDescriptor(adaptor.getMemref())
543 .memRefDescPtr(rewriter, op.getLoc()),
544 elementPtrTy);
545 } else {
546 allocatedPtr = MemRefDescriptor(adaptor.getMemref())
547 .allocatedPtr(rewriter, op.getLoc());
548 }
549 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
550 allocatedPtr);
551 return success();
552 }
553};
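
// For illustration, `memref.dealloc %m : memref<?xf32>` becomes roughly:
//
//   %alloc = llvm.extractvalue %desc[0]
//       : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
//   llvm.call @free(%alloc) : (!llvm.ptr) -> ()
//
// Only the allocated pointer, never the aligned one, may be passed to free.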
554
555// A `dim` is converted to a constant for static sizes and to an access to the
556// size stored in the memref descriptor for dynamic sizes.
557struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
558 using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
559
560 LogicalResult
561 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
562 ConversionPatternRewriter &rewriter) const override {
563 Type operandType = dimOp.getSource().getType();
564 if (isa<UnrankedMemRefType>(operandType)) {
565 FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
566 operandType, dimOp, adaptor.getOperands(), rewriter);
567 if (failed(extractedSize))
568 return failure();
569 rewriter.replaceOp(dimOp, {*extractedSize});
570 return success();
571 }
572 if (isa<MemRefType>(operandType)) {
573 rewriter.replaceOp(
574 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
575 adaptor.getOperands(), rewriter)});
576 return success();
577 }
578 llvm_unreachable("expected MemRefType or UnrankedMemRefType");
579 }
580
581private:
582 FailureOr<Value>
583 extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
584 OpAdaptor adaptor,
585 ConversionPatternRewriter &rewriter) const {
586 Location loc = dimOp.getLoc();
587
588 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
589 auto scalarMemRefType =
590 MemRefType::get({}, unrankedMemRefType.getElementType());
591 FailureOr<unsigned> maybeAddressSpace =
592 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
593 if (failed(maybeAddressSpace)) {
594 dimOp.emitOpError("memref memory space must be convertible to an integer "
595 "address space");
596 return failure();
597 }
598 unsigned addressSpace = *maybeAddressSpace;
599
600 // Extract pointer to the underlying ranked descriptor and bitcast it to a
601 // memref<element_type> descriptor pointer to minimize the number of GEP
602 // operations.
603 UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
604 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
605
606 Type elementType = typeConverter->convertType(scalarMemRefType);
607
608 // Get pointer to offset field of memref<element_type> descriptor.
609 auto indexPtrTy =
610 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
611 Value offsetPtr =
612 LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
613 underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2});
614
615 // The size value that we have to extract can be obtained using GEPOp with
616 // `dimOp.getIndex() + 1` as its index argument.
617 Value idxPlusOne = LLVM::AddOp::create(
618 rewriter, loc,
619 createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
620 adaptor.getIndex());
621 Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
622 getTypeConverter()->getIndexType(),
623 offsetPtr, idxPlusOne);
624 return LLVM::LoadOp::create(rewriter, loc,
625 getTypeConverter()->getIndexType(), sizePtr)
626 .getResult();
627 }
628
629 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
630 if (auto idx = dimOp.getConstantIndex())
631 return idx;
632
633 if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
634 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
635
636 return std::nullopt;
637 }
638
639 Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
640 OpAdaptor adaptor,
641 ConversionPatternRewriter &rewriter) const {
642 Location loc = dimOp.getLoc();
643
644 // Take advantage of a constant index when possible.
645 MemRefType memRefType = cast<MemRefType>(operandType);
646 Type indexType = getIndexType();
647 if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
648 int64_t i = *index;
649 if (i >= 0 && i < memRefType.getRank()) {
650 if (memRefType.isDynamicDim(i)) {
651 // extract dynamic size from the memref descriptor.
652 MemRefDescriptor descriptor(adaptor.getSource());
653 return descriptor.size(rewriter, loc, i);
654 }
655 // Use constant for static size.
656 int64_t dimSize = memRefType.getDimSize(i);
657 return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
658 }
659 }
660 Value index = adaptor.getIndex();
661 int64_t rank = memRefType.getRank();
662 MemRefDescriptor memrefDescriptor(adaptor.getSource());
663 return memrefDescriptor.size(rewriter, loc, index, rank);
664 }
665};
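
// For illustration, `%d = memref.dim %m, %c1 : memref<4x?xf32>` with a
// constant index takes the ranked path and reads the dynamic size straight
// out of the descriptor, roughly:
//
//   %d = llvm.extractvalue %desc[3, 1]
//       : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>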
666
667/// Common base for load and store operations on MemRefs. Restricts the match
668/// to supported MemRef types. Provides functionality to emit code accessing a
669/// specific element of the underlying data buffer.
670template <typename Derived>
671struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
672 using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
673 using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
674 using Base = LoadStoreOpLowering<Derived>;
675};
676
677/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
678/// retried until it succeeds in atomically storing a new value into memory.
679///
680/// +---------------------------------+
681/// | <code before the AtomicRMWOp> |
682/// | <compute initial %loaded> |
683/// | cf.br loop(%loaded) |
684/// +---------------------------------+
685/// |
686/// -------| |
687/// | v v
688/// | +--------------------------------+
689/// | | loop(%loaded): |
690/// | | <body contents> |
691/// | | %pair = cmpxchg |
692/// | | %ok = %pair[0] |
693/// | | %new = %pair[1] |
694/// | | cf.cond_br %ok, end, loop(%new) |
695/// | +--------------------------------+
696/// | | |
697/// |----------- |
698/// v
699/// +--------------------------------+
700/// | end: |
701/// | <code after the AtomicRMWOp> |
702/// +--------------------------------+
703///
704struct GenericAtomicRMWOpLowering
705 : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
706 using Base::Base;
707
708 LogicalResult
709 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
710 ConversionPatternRewriter &rewriter) const override {
711 auto loc = atomicOp.getLoc();
712 Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
713
714 // Split the block into initial, loop, and ending parts.
715 auto *initBlock = rewriter.getInsertionBlock();
716 auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
717 loopBlock->addArgument(valueType, loc);
718
719 auto *endBlock =
720 rewriter.splitBlock(loopBlock, Block::iterator(atomicOp));
721
722 // Compute the loaded value and branch to the loop block.
723 rewriter.setInsertionPointToEnd(initBlock);
724 auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
725 auto dataPtr = getStridedElementPtr(
726 rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
727 Value init = LLVM::LoadOp::create(
728 rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
729 dataPtr);
730 LLVM::BrOp::create(rewriter, loc, init, loopBlock);
731
732 // Prepare the body of the loop block.
733 rewriter.setInsertionPointToStart(loopBlock);
734
735 // Clone the GenericAtomicRMWOp region and extract the result.
736 auto loopArgument = loopBlock->getArgument(0);
737 IRMapping mapping;
738 mapping.map(atomicOp.getCurrentValue(), loopArgument);
739 Block &entryBlock = atomicOp.body().front();
740 for (auto &nestedOp : entryBlock.without_terminator()) {
741 Operation *clone = rewriter.clone(nestedOp, mapping);
742 mapping.map(nestedOp.getResults(), clone->getResults());
743 }
744 Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
745
746 // Prepare the epilog of the loop block.
747 // Append the cmpxchg op to the end of the loop block.
748 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
749 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
750 auto cmpxchg =
751 LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
752 result, successOrdering, failureOrdering);
753 // Extract the %new_loaded and %ok values from the pair.
754 Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
755 Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);
756
757 // Conditionally branch to the end or back to the loop depending on %ok.
758 LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef<Value>(),
759 loopBlock, newLoaded);
760
761 rewriter.setInsertionPointToEnd(endBlock);
762
763 // The 'result' of the atomic_rmw op is the newly loaded value.
764 rewriter.replaceOp(atomicOp, {newLoaded});
765
766 return success();
767 }
768};
769
770/// Returns the LLVM type of the global variable given the memref type `type`.
771static Type
772convertGlobalMemrefTypeToLLVM(MemRefType type,
773 const LLVMTypeConverter &typeConverter) {
774 // The LLVM type for a global memref is a multi-dimensional array. For
775 // declarations or uninitialized global memrefs, we could potentially
776 // flatten this to a 1D array. However, for memref.global's with an initial
777 // value, we do not intend to flatten the ElementsAttribute when lowering
778 // to the LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
779 Type elementType = typeConverter.convertType(type.getElementType());
780 Type arrayTy = elementType;
781 // Shape has the outermost dim at index 0, so we need to walk it backwards.
782 for (int64_t dim : llvm::reverse(type.getShape()))
783 arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
784 return arrayTy;
785}
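
// For example, memref<4x8xf32> maps to !llvm.array<4 x array<8 x f32>>,
// while a rank-0 memref<f32> maps to plain f32 (the loop body never runs).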
786
787/// GlobalMemrefOp is lowered to a LLVM Global Variable.
788class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
789 SymbolTableCollection *symbolTables = nullptr;
790
791public:
792 explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter,
793 SymbolTableCollection *symbolTables = nullptr,
794 PatternBenefit benefit = 1)
795 : ConvertOpToLLVMPattern<memref::GlobalOp>(typeConverter, benefit),
796 symbolTables(symbolTables) {}
797
798 LogicalResult
799 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
800 ConversionPatternRewriter &rewriter) const override {
801 MemRefType type = global.getType();
802 if (!isConvertibleAndHasIdentityMaps(type))
803 return failure();
804
805 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
806
807 LLVM::Linkage linkage =
808 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
809 bool isExternal = global.isExternal();
810 bool isUninitialized = global.isUninitialized();
811
812 Attribute initialValue = nullptr;
813 if (!isExternal && !isUninitialized) {
814 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
815 initialValue = elementsAttr;
816
817 // For scalar memrefs, the global variable created is of the element type,
818 // so unpack the elements attribute to extract the value.
819 if (type.getRank() == 0)
820 initialValue = elementsAttr.getSplatValue<Attribute>();
821 }
822
823 uint64_t alignment = global.getAlignment().value_or(0);
824 FailureOr<unsigned> addressSpace =
825 getTypeConverter()->getMemRefAddressSpace(type);
826 if (failed(addressSpace))
827 return global.emitOpError(
828 "memory space cannot be converted to an integer address space");
829
830 // Remove old operation from symbol table.
831 SymbolTable *symbolTable = nullptr;
832 if (symbolTables) {
833 Operation *symbolTableOp =
834 global->getParentWithTrait<OpTrait::SymbolTable>();
835 symbolTable = &symbolTables->getSymbolTable(symbolTableOp);
836 symbolTable->remove(global);
837 }
838
839 // Create new operation.
840 auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
841 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
842 initialValue, alignment, *addressSpace);
843
844 // Insert new operation into symbol table.
845 if (symbolTable)
846 symbolTable->insert(newGlobal, rewriter.getInsertionPoint());
847
848 if (!isExternal && isUninitialized) {
849 rewriter.createBlock(&newGlobal.getInitializerRegion());
850 Value undef[] = {
851 LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
852 LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
853 }
854 return success();
855 }
856};
857
858/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
859/// the first element stashed into the descriptor. It reuses the same
860/// descriptor-construction helpers as the alloc-style lowerings above.
861struct GetGlobalMemrefOpLowering
862 : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
863 using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;
864
865 /// Buffer "allocation" for memref.get_global op is getting the address of
866 /// the global variable referenced.
867 LogicalResult
868 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
869 ConversionPatternRewriter &rewriter) const override {
870 auto loc = op.getLoc();
871 MemRefType memRefType = op.getType();
872 if (!isConvertibleAndHasIdentityMaps(memRefType))
873 return rewriter.notifyMatchFailure(op, "incompatible memref type");
874
875 // Get actual sizes of the memref as values: static sizes are constant
876 // values and dynamic sizes are passed to 'alloc' as operands. In case of
877 // zero-dimensional memref, assume a scalar (size 1).
878 SmallVector<Value, 4> sizes;
879 SmallVector<Value, 4> strides;
880 Value sizeBytes;
881
882 this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
883 rewriter, sizes, strides, sizeBytes, /*sizeInBytes=*/true);
884
885 MemRefType type = cast<MemRefType>(op.getResult().getType());
886
887 // This is called after a type conversion, which would have failed if this
888 // call fails.
889 FailureOr<unsigned> maybeAddressSpace =
890 getTypeConverter()->getMemRefAddressSpace(type);
891 assert(succeeded(maybeAddressSpace) && "unsupported address space");
892 unsigned memSpace = *maybeAddressSpace;
893
894 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
895 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
896 auto addressOf =
897 LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());
898
899 // Get the address of the first element in the array by creating a GEP with
900 // the address of the GV as the base, and (rank + 1) number of 0 indices.
901 auto gep =
902 LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
903 SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
904
905 // We do not expect the memref obtained using `memref.get_global` to be
906 // ever deallocated. Set the allocated pointer to a known-bad value to
907 // help debug if that ever happens.
908 auto intPtrType = getIntPtrType(memSpace);
909 Value deadBeefConst =
910 createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
911 auto deadBeefPtr =
912 LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);
913
914 // Both allocated and aligned pointers are same. We could potentially stash
915 // a nullptr for the allocated pointer since we do not expect any dealloc.
916 // Create the MemRef descriptor.
917 auto memRefDescriptor = this->createMemRefDescriptor(
918 loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);
919
920 // Return the final value of the descriptor.
921 rewriter.replaceOp(op, {memRefDescriptor});
922 return success();
923 }
924};
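
// For illustration, `%m = memref.get_global @g : memref<2x3xf32>` becomes
// roughly (rank + 1 = 3 zero indices on the GEP):
//
//   %addr = llvm.mlir.addressof @g : !llvm.ptr
//   %gep  = llvm.getelementptr %addr[0, 0, 0]
//       : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x array<3 x f32>>
//
// with the 0xdeadbeef constant stashed as the never-freed allocated pointer.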
925
926// Load operation is lowered to obtaining a pointer to the indexed element
927// and loading it.
928struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
929 using Base::Base;
930
931 LogicalResult
932 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
933 ConversionPatternRewriter &rewriter) const override {
934 auto type = loadOp.getMemRefType();
935
936 // Per memref.load spec, the indices must be in-bounds:
937 // 0 <= idx < dim_size, and additionally all offsets are non-negative,
938 // hence inbounds and nuw are used when lowering to llvm.getelementptr.
939 Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
940 adaptor.getMemref(),
941 adaptor.getIndices(), kNoWrapFlags);
942 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
943 loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
944 loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal());
945 return success();
946 }
947};
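
// For illustration, `%v = memref.load %m[%i, %j] : memref<?x?xf32>` becomes
// roughly (the linearized index %idx = offset + %i * stride0 + %j * stride1
// is computed by getStridedElementPtr and folded into a single GEP):
//
//   %p = llvm.getelementptr inbounds nuw %aligned[%idx]
//       : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   %v = llvm.load %p : !llvm.ptr -> f32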
948
949// Store operation is lowered to obtaining a pointer to the indexed element,
950// and storing the given value to it.
951struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
952 using Base::Base;
953
954 LogicalResult
955 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
956 ConversionPatternRewriter &rewriter) const override {
957 auto type = op.getMemRefType();
958
959 // Per memref.store spec, the indices must be in-bounds:
960 // 0 <= idx < dim_size, and additionally all offsets are non-negative,
961 // hence inbounds and nuw are used when lowering to llvm.getelementptr.
962 Value dataPtr =
963 getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
964 adaptor.getIndices(), kNoWrapFlags);
965 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
966 op.getAlignment().value_or(0),
967 false, op.getNontemporal());
968 return success();
969 }
970};
971
972// The prefetch operation is lowered in a way similar to the load operation
973// except that the llvm.prefetch operation is used for replacement.
974struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
975 using Base::Base;
976
977 LogicalResult
978 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
979 ConversionPatternRewriter &rewriter) const override {
980 auto type = prefetchOp.getMemRefType();
981 auto loc = prefetchOp.getLoc();
982
983 Value dataPtr = getStridedElementPtr(
984 rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
985
986 // Replace with llvm.prefetch.
987 IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
988 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
989 IntegerAttr isData =
990 rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
991 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
992 localityHint, isData);
993 return success();
994 }
995};
996
997struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
998 using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
999
1000 LogicalResult
1001 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
1002 ConversionPatternRewriter &rewriter) const override {
1003 Location loc = op.getLoc();
1004 Type operandType = op.getMemref().getType();
1005 if (isa<UnrankedMemRefType>(operandType)) {
1006 UnrankedMemRefDescriptor desc(adaptor.getMemref());
1007 rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
1008 return success();
1009 }
1010 if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
1011 Type indexType = getIndexType();
1012 rewriter.replaceOp(op,
1013 {createIndexAttrConstant(rewriter, loc, indexType,
1014 rankedMemRefType.getRank())});
1015 return success();
1016 }
1017 return failure();
1018 }
1019};
1020
1021struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
1022 using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
1023
1024 LogicalResult
1025 matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
1026 ConversionPatternRewriter &rewriter) const override {
1027 Type srcType = memRefCastOp.getOperand().getType();
1028 Type dstType = memRefCastOp.getType();
1029
1030 // memref::CastOp reduces to a bitcast in the ranked MemRef case and can
1031 // be used for type erasure. For now the cast must preserve the underlying
1032 // element type and the source and result ranks must match. Therefore,
1033 // perform a sanity check that the underlying structs are the same. Once op
1034 // semantics are relaxed we can revisit.
1035 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
1036 if (typeConverter->convertType(srcType) !=
1037 typeConverter->convertType(dstType))
1038 return failure();
1039
1040 // Unranked to unranked cast is disallowed
1041 if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
1042 return failure();
1043
1044 auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
1045 auto loc = memRefCastOp.getLoc();
1046
1047 // For ranked/ranked case, just keep the original descriptor.
1048 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
1049 rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
1050 return success();
1051 }
1052
1053 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
1054 // Casting ranked to unranked memref type
1055 // Set the rank in the destination from the memref type
1056 // Allocate space on the stack and copy the src memref descriptor
1057 // Set the ptr in the destination to the stack space
1058 auto srcMemRefType = cast<MemRefType>(srcType);
1059 int64_t rank = srcMemRefType.getRank();
1060 // ptr = AllocaOp sizeof(MemRefDescriptor)
1061 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
1062 loc, adaptor.getSource(), rewriter);
1063
1064 // rank = ConstantOp srcRank
1065 auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1066 rewriter.getIndexAttr(rank));
1067 // poison = PoisonOp
1068 UnrankedMemRefDescriptor memRefDesc =
1069 UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
1070 // d1 = InsertValueOp poison, rank, 0
1071 memRefDesc.setRank(rewriter, loc, rankVal);
1072 // d2 = InsertValueOp d1, ptr, 1
1073 memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
1074 rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
1075
1076 } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
1077 // Casting from unranked type to ranked.
1078 // The operation is assumed to be doing a correct cast. If the destination
1079 // type mismatches the unranked type, the behavior is undefined.
1080 UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
1081 // ptr = ExtractValueOp src, 1
1082 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
1083
1084 // struct = LoadOp ptr
1085 auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr);
1086 rewriter.replaceOp(memRefCastOp, loadOp.getResult());
1087 } else {
1088 llvm_unreachable("Unsupported unranked memref to unranked memref cast");
1089 }
1090
1091 return success();
1092 }
1093};
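
// For illustration, erasing the rank with
// `memref.cast %m : memref<2xf32> to memref<*xf32>` stores the ranked
// descriptor into a stack slot and packs (rank, pointer), roughly:
//
//   %slot = llvm.alloca ...       ; llvm.store %desc, %slot
//   %u0 = llvm.mlir.poison : !llvm.struct<(i64, ptr)>
//   %u1 = llvm.insertvalue %c1, %u0[0] : ...   // rank = 1
//   %u2 = llvm.insertvalue %slot, %u1[1] : ...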
1094
1095/// Pattern to lower a `memref.copy` to llvm.
1096///
1097/// For memrefs with identity layouts, the copy is lowered to the llvm
1098/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
1099/// to the generic `MemrefCopyFn`.
1100class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
1101 SymbolTableCollection *symbolTables = nullptr;
1102
1103public:
1104 explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter,
1105 SymbolTableCollection *symbolTables = nullptr,
1106 PatternBenefit benefit = 1)
1107 : ConvertOpToLLVMPattern<memref::CopyOp>(typeConverter, benefit),
1108 symbolTables(symbolTables) {}
1109
1110 LogicalResult
1111 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
1112 ConversionPatternRewriter &rewriter) const {
1113 auto loc = op.getLoc();
1114 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
1115
1116 MemRefDescriptor srcDesc(adaptor.getSource());
1117
1118 // Compute number of elements.
1119 Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1120 rewriter.getIndexAttr(1));
1121 for (int pos = 0; pos < srcType.getRank(); ++pos) {
1122 auto size = srcDesc.size(rewriter, loc, pos);
1123 numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
1124 }
1125
1126 // Get element size.
1127 auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
1128 // Compute total.
1129 Value totalSize =
1130 LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);
1131
1132 Type elementType = typeConverter->convertType(srcType.getElementType());
1133
1134 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
1135 Value srcOffset = srcDesc.offset(rewriter, loc);
1136 Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(),
1137 elementType, srcBasePtr, srcOffset);
1138 MemRefDescriptor targetDesc(adaptor.getTarget());
1139 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
1140 Value targetOffset = targetDesc.offset(rewriter, loc);
1141 Value targetPtr =
1142 LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType,
1143 targetBasePtr, targetOffset);
1144 LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
1145 /*isVolatile=*/false);
1146 rewriter.eraseOp(op);
1147
1148 return success();
1149 }
1150
1151 LogicalResult
1152 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
1153 ConversionPatternRewriter &rewriter) const {
1154 auto loc = op.getLoc();
1155 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1156 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1157
1158 // First make sure we have an unranked memref descriptor representation.
1159 auto makeUnranked = [&, this](Value ranked, MemRefType type) {
1160 auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1161 type.getRank());
1162 auto *typeConverter = getTypeConverter();
1163 auto ptr =
1164 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
1165
1166 auto unrankedType =
1167 UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
1168 return UnrankedMemRefDescriptor::pack(
1169 rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
1170 };
1171
1172 // Save stack position before promoting descriptors
1173 auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
1174
1175 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
1176 Value unrankedSource =
1177 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
1178 : adaptor.getSource();
1179 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
1180 Value unrankedTarget =
1181 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
1182 : adaptor.getTarget();
1183
1184 // Now promote the unranked descriptors to the stack.
1185 auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
1186 rewriter.getIndexAttr(1));
1187 auto promote = [&](Value desc) {
1188 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1189 auto allocated =
1190 LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
1191 LLVM::StoreOp::create(rewriter, loc, desc, allocated);
1192 return allocated;
1193 };
1194
1195 auto sourcePtr = promote(unrankedSource);
1196 auto targetPtr = promote(unrankedTarget);
1197
1198 // Derive size from llvm.getelementptr which will account for any
1199 // potential alignment
1200 auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
1201 auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
1202 rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
1203 sourcePtr.getType(), symbolTables);
1204 if (failed(copyFn))
1205 return failure();
1206 LLVM::CallOp::create(rewriter, loc, copyFn.value(),
1207 ValueRange{elemSize, sourcePtr, targetPtr});
1208
1209 // Restore stack used for descriptors
1210 LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);
1211
1212 rewriter.eraseOp(op);
1213
1214 return success();
1215 }
1216
1217 LogicalResult
1218 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
1219 ConversionPatternRewriter &rewriter) const override {
1220 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
1221 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
1222
1223 auto isContiguousMemrefType = [&](BaseMemRefType type) {
1224 auto memrefType = dyn_cast<mlir::MemRefType>(type);
1225 // We can use memcpy for memrefs if they have an identity layout or are
1226 // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
1227 // special case handled by memrefCopy.
1228 return memrefType &&
1229 (memrefType.getLayout().isIdentity() ||
1230 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
1231 memref::isStaticShapeAndContiguousRowMajor(memrefType)));
1232 };
1233
1234 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
1235 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
1236
1237 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
1238 }
1239};
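
// On the intrinsic path the byte count is numElements * elementSize: e.g.
// copying a memref<4x5xf32> issues a single llvm.intr.memcpy of
// 4 * 5 * 4 = 80 bytes between the two offset-adjusted aligned pointers
// computed above.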
1240
1241struct MemorySpaceCastOpLowering
1242 : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
1243 using ConvertOpToLLVMPattern<
1244 memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;
1245
1246 LogicalResult
1247 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
1248 ConversionPatternRewriter &rewriter) const override {
1249 Location loc = op.getLoc();
1250
1251 Type resultType = op.getDest().getType();
1252 if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
1253 auto resultDescType =
1254 cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
1255 Type newPtrType = resultDescType.getBody()[0];
1256
1257 SmallVector<Value> descVals;
1258 MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
1259 descVals);
1260 descVals[0] =
1261 LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
1262 descVals[1] =
1263 LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
1264 Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
1265 resultTypeR, descVals);
1266 rewriter.replaceOp(op, result);
1267 return success();
1268 }
1269 if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
1270 // Since the type converter won't be doing this for us, get the address
1271 // space.
1272 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
1273 FailureOr<unsigned> maybeSourceAddrSpace =
1274 getTypeConverter()->getMemRefAddressSpace(sourceType);
1275 if (failed(maybeSourceAddrSpace))
1276 return rewriter.notifyMatchFailure(loc,
1277 "non-integer source address space");
1278 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
1279 FailureOr<unsigned> maybeResultAddrSpace =
1280 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
1281 if (failed(maybeResultAddrSpace))
1282 return rewriter.notifyMatchFailure(loc,
1283 "non-integer result address space");
1284 unsigned resultAddrSpace = *maybeResultAddrSpace;
1285
1286 UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
1287 Value rank = sourceDesc.rank(rewriter, loc);
1288 Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
1289
1290 // Create and allocate storage for new memref descriptor.
1291 auto result = UnrankedMemRefDescriptor::poison(
1292 rewriter, loc, typeConverter->convertType(resultTypeU));
1293 result.setRank(rewriter, loc, rank);
1294 Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
1295 rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
1296 Value resultUnderlyingDesc =
1297 LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
1298 rewriter.getI8Type(), resultUnderlyingSize);
1299 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
1300
1301 // Copy pointers, performing address space casts.
1302 auto sourceElemPtrType =
1303 LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
1304 auto resultElemPtrType =
1305 LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);
1306
1307 Value allocatedPtr = sourceDesc.allocatedPtr(
1308 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
1309 Value alignedPtr =
1310 sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
1311 sourceUnderlyingDesc, sourceElemPtrType);
1312 allocatedPtr = LLVM::AddrSpaceCastOp::create(
1313 rewriter, loc, resultElemPtrType, allocatedPtr);
1314 alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
1315 resultElemPtrType, alignedPtr);
1316
1317 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
1318 resultElemPtrType, allocatedPtr);
1319 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
1320 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
1321
1322 // Copy all the index-valued operands.
1323 Value sourceIndexVals =
1324 sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1325 sourceUnderlyingDesc, sourceElemPtrType);
1326 Value resultIndexVals =
1327 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
1328 resultUnderlyingDesc, resultElemPtrType);
1329
1330 int64_t bytesToSkip =
1331 2 * llvm::divideCeil(
1332 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
1333 Value bytesToSkipConst = LLVM::ConstantOp::create(
1334 rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
1335 Value copySize =
1336 LLVM::SubOp::create(rewriter, loc, getIndexType(),
1337 resultUnderlyingSize, bytesToSkipConst);
1338 LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
1339 copySize, /*isVolatile=*/false);
1340
1341 rewriter.replaceOp(op, ValueRange{result});
1342 return success();
1343 }
1344 return rewriter.notifyMatchFailure(loc, "unexpected memref type");
1345 }
1346};
1347
1348/// Extracts allocated, aligned pointers and offset from a ranked or unranked
1349/// memref type. In unranked case, the fields are extracted from the underlying
1350/// ranked descriptor.
1351static void extractPointersAndOffset(Location loc,
1352 ConversionPatternRewriter &rewriter,
1353 const LLVMTypeConverter &typeConverter,
1354 Value originalOperand,
1355 Value convertedOperand,
1356 Value *allocatedPtr, Value *alignedPtr,
1357 Value *offset = nullptr) {
1358 Type operandType = originalOperand.getType();
1359 if (isa<MemRefType>(operandType)) {
1360 MemRefDescriptor desc(convertedOperand);
1361 *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1362 *alignedPtr = desc.alignedPtr(rewriter, loc);
1363 if (offset != nullptr)
1364 *offset = desc.offset(rewriter, loc);
1365 return;
1366 }
1367
1368 // These will all cause assert()s on unconvertible types.
1369 unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
1370 cast<UnrankedMemRefType>(operandType));
1371 auto elementPtrType =
1372 LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);
1373
1374 // Extract pointer to the underlying ranked memref descriptor and cast it to
1375 // ElemType**.
1376 UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1377 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1378
1378
1379 *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
1380 rewriter, loc, underlyingDescPtr, elementPtrType);
1381 *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1382 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1383 if (offset != nullptr) {
1384 *offset = UnrankedMemRefDescriptor::offset(
1385 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1386 }
1387}
1388
1389struct MemRefReinterpretCastOpLowering
1390 : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1391 using ConvertOpToLLVMPattern<
1392 memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1393
1394 LogicalResult
1395 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1396 ConversionPatternRewriter &rewriter) const override {
1397 Type srcType = castOp.getSource().getType();
1398
1399 Value descriptor;
1400 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1401 adaptor, &descriptor)))
1402 return failure();
1403 rewriter.replaceOp(castOp, {descriptor});
1404 return success();
1405 }
1406
1407private:
1408 LogicalResult convertSourceMemRefToDescriptor(
1409 ConversionPatternRewriter &rewriter, Type srcType,
1410 memref::ReinterpretCastOp castOp,
1411 memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1412 MemRefType targetMemRefType =
1413 cast<MemRefType>(castOp.getResult().getType());
1414 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1415 typeConverter->convertType(targetMemRefType));
1416 if (!llvmTargetDescriptorTy)
1417 return failure();
1418
1419 // Create descriptor.
1420 Location loc = castOp.getLoc();
1421 auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1422
1423 // Set allocated and aligned pointers.
1424 Value allocatedPtr, alignedPtr;
1425 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1426 castOp.getSource(), adaptor.getSource(),
1427 &allocatedPtr, &alignedPtr);
1428 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1429 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1430
1431 // Set offset.
1432 if (castOp.isDynamicOffset(0))
1433 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1434 else
1435 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1436
1437 // Set sizes and strides.
1438 unsigned dynSizeId = 0;
1439 unsigned dynStrideId = 0;
1440 for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1441 if (castOp.isDynamicSize(i))
1442 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1443 else
1444 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1445
1446 if (castOp.isDynamicStride(i))
1447 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1448 else
1449 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1450 }
1451 *descriptor = desc;
1452 return success();
1453 }
1454};
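// As a rough example of what this pattern produces (hypothetical IR):
//
//   %0 = memref.reinterpret_cast %src to offset: [0], sizes: [4, %d],
//       strides: [%s, 1] : memref<?xf32> to memref<4x?xf32, strided<[?, 1]>>
//
// lowers to descriptor surgery only: both pointers are copied from %src's
// descriptor into a fresh poison descriptor, static offset/size/stride values
// are written as constants, and the dynamic values %d and %s are written from
// the adaptor operands. No payload memory is read or written.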
1455
1456struct MemRefReshapeOpLowering
1457 : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1458 using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1459
1460 LogicalResult
1461 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1462 ConversionPatternRewriter &rewriter) const override {
1463 Type srcType = reshapeOp.getSource().getType();
1464
1465 Value descriptor;
1466 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1467 adaptor, &descriptor)))
1468 return failure();
1469 rewriter.replaceOp(reshapeOp, {descriptor});
1470 return success();
1471 }
1472
1473private:
1474 LogicalResult
1475 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1476 Type srcType, memref::ReshapeOp reshapeOp,
1477 memref::ReshapeOp::Adaptor adaptor,
1478 Value *descriptor) const {
1479 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1480 if (shapeMemRefType.hasStaticShape()) {
1481 MemRefType targetMemRefType =
1482 cast<MemRefType>(reshapeOp.getResult().getType());
1483 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1484 typeConverter->convertType(targetMemRefType));
1485 if (!llvmTargetDescriptorTy)
1486 return failure();
1487
1488 // Create descriptor.
1489 Location loc = reshapeOp.getLoc();
1490 auto desc =
1491 MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1492
1493 // Set allocated and aligned pointers.
1494 Value allocatedPtr, alignedPtr;
1495 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1496 reshapeOp.getSource(), adaptor.getSource(),
1497 &allocatedPtr, &alignedPtr);
1498 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1499 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1500
1501 // Extract the offset and strides from the type.
1502 int64_t offset;
1503 SmallVector<int64_t> strides;
1504 if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
1505 return rewriter.notifyMatchFailure(
1506 reshapeOp, "failed to get stride and offset exprs");
1507
1508 if (!isStaticStrideOrOffset(offset))
1509 return rewriter.notifyMatchFailure(reshapeOp,
1510 "dynamic offset is unsupported");
1511
1512 desc.setConstantOffset(rewriter, loc, offset);
1513
1514 assert(targetMemRefType.getLayout().isIdentity() &&
1515 "Identity layout map is a precondition of a valid reshape op");
1516
1517 Type indexType = getIndexType();
1518 Value stride = nullptr;
1519 int64_t targetRank = targetMemRefType.getRank();
1520 for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1521 if (ShapedType::isStatic(strides[i])) {
1522 // The stride for this dimension is static, so materialize it directly
1523 // as a constant.
1524 stride =
1525 createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
1526 } else if (!stride) {
1527 // `stride` is null only in the first iteration of the loop. However,
1528 // since the target memref has an identity layout, we can safely set
1529 // the innermost stride to 1.
1530 stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
1531 }
1532
1533 Value dimSize;
1534 // Use the static size when it is known; otherwise load the size at
1535 // runtime from the shape operand.
1536 if (!targetMemRefType.isDynamicDim(i)) {
1537 dimSize = createIndexAttrConstant(rewriter, loc, indexType,
1538 targetMemRefType.getDimSize(i));
1539 } else {
1540 Value shapeOp = reshapeOp.getShape();
1541 Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
1542 dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index);
1543 Type indexType = getIndexType();
1544 if (dimSize.getType() != indexType)
1545 dimSize = typeConverter->materializeTargetConversion(
1546 rewriter, loc, indexType, dimSize);
1547 assert(dimSize && "Invalid memref element type");
1548 }
1549
1550 desc.setSize(rewriter, loc, i, dimSize);
1551 desc.setStride(rewriter, loc, i, stride);
1552
1553 // Prepare the stride value for the next dimension.
1554 stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
1555 }
1556
1557 *descriptor = desc;
1558 return success();
1559 }
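// A worked example of the loop above, assuming a target type of
// memref<2x?x4xf32> with identity layout, iterating i = 2, 1, 0:
//   i = 2: stride = 1 (static constant)  size = 4 (static constant)
//   i = 1: stride = 4 (static constant)  size = loaded from the shape memref
//   i = 0: stride = 4 * size[1]          size = 2 (static constant)
// i.e. row-major strides accumulated as a running product of inner sizes.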
1560
1561 // The shape is a rank-1 memref with unknown length.
1562 Location loc = reshapeOp.getLoc();
1563 MemRefDescriptor shapeDesc(adaptor.getShape());
1564 Value resultRank = shapeDesc.size(rewriter, loc, 0);
1565
1566 // Extract address space and element type.
1567 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1568 unsigned addressSpace =
1569 *getTypeConverter()->getMemRefAddressSpace(targetType);
1570
1571 // Create the unranked memref descriptor that holds the ranked one. The
1572 // inner descriptor is allocated on the stack.
1573 auto targetDesc = UnrankedMemRefDescriptor::poison(
1574 rewriter, loc, typeConverter->convertType(targetType));
1575 targetDesc.setRank(rewriter, loc, resultRank);
1576 Value allocationSize = UnrankedMemRefDescriptor::computeSize(
1577 rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
1578 Value underlyingDescPtr = LLVM::AllocaOp::create(
1579 rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
1580 allocationSize);
1581 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1582
1583 // Extract pointers and offset from the source memref.
1584 Value allocatedPtr, alignedPtr, offset;
1585 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1586 reshapeOp.getSource(), adaptor.getSource(),
1587 &allocatedPtr, &alignedPtr, &offset);
1588
1589 // Set pointers and offset.
1590 auto elementPtrType =
1591 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1592
1593 UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1594 elementPtrType, allocatedPtr);
1595 UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1596 underlyingDescPtr, elementPtrType,
1597 alignedPtr);
1598 UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1599 underlyingDescPtr, elementPtrType,
1600 offset);
1601
1602 // Use the offset pointer as base for further addressing. Copy over the new
1603 // shape and compute strides. For this, we create a loop from rank-1 down to 0.
1604 Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1605 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1606 Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1607 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1608 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1609 Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
1610 Value resultRankMinusOne =
1611 LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);
1612
1613 Block *initBlock = rewriter.getInsertionBlock();
1614 Type indexType = getTypeConverter()->getIndexType();
1615 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1616
1617 Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1618 {indexType, indexType}, {loc, loc});
1619
1620 // Move the remaining initBlock ops to condBlock.
1621 Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1622 rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1623
1624 rewriter.setInsertionPointToEnd(initBlock);
1625 LLVM::BrOp::create(rewriter, loc,
1626 ValueRange({resultRankMinusOne, oneIndex}), condBlock);
1627 rewriter.setInsertionPointToStart(condBlock);
1628 Value indexArg = condBlock->getArgument(0);
1629 Value strideArg = condBlock->getArgument(1);
1630
1631 Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
1632 Value pred = LLVM::ICmpOp::create(
1633 rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
1634 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1635
1636 Block *bodyBlock =
1637 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1638 rewriter.setInsertionPointToStart(bodyBlock);
1639
1640 // Copy size from shape to descriptor.
1641 auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1642 Value sizeLoadGep = LLVM::GEPOp::create(
1643 rewriter, loc, llvmIndexPtrType,
1644 typeConverter->convertType(shapeMemRefType.getElementType()),
1645 shapeOperandPtr, indexArg);
1646 Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
1647 UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1648 targetSizesBase, indexArg, size);
1649
1650 // Write stride value and compute next one.
1651 UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1652 targetStridesBase, indexArg, strideArg);
1653 Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);
1654
1655 // Decrement loop counter and branch back.
1656 Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
1657 LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}),
1658 condBlock);
1659
1660 Block *remainder =
1661 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1662
1663 // Hook up the cond exit to the remainder.
1664 rewriter.setInsertionPointToEnd(condBlock);
1665 LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(),
1666 remainder, ValueRange());
1667
1668 // Reset position to beginning of new remainder block.
1669 rewriter.setInsertionPointToStart(remainder);
1670
1671 *descriptor = targetDesc;
1672 return success();
1673 }
1674};
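// For the dynamic-rank branch above, the emitted control flow looks roughly
// like the following pseudo-IR (block names are hypothetical):
//
//   init:  llvm.br ^cond(rank - 1, 1)
//   ^cond(%i: index, %stride: index):
//     %p = llvm.icmp "sge" %i, 0
//     llvm.cond_br %p, ^body, ^remainder
//   ^body:
//     sizes[%i] = load shape[%i]
//     strides[%i] = %stride
//     llvm.br ^cond(%i - 1, %stride * sizes[%i])
//
// i.e. sizes are copied from the shape operand and strides are built as a
// running product, walking the dimensions from innermost to outermost.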
1675
1676 /// Reassociating reshape ops must be expanded before we reach this stage;
1677 /// if one survives, report a match failure.
1678template <typename ReshapeOp>
1679class ReassociatingReshapeOpConversion
1680 : public ConvertOpToLLVMPattern<ReshapeOp> {
1681public:
1682 using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1683 using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1684
1685 LogicalResult
1686 matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1687 ConversionPatternRewriter &rewriter) const override {
1688 return rewriter.notifyMatchFailure(
1689 reshapeOp,
1690 "reassociation operations should have been expanded beforehand");
1691 }
1692};
1693
1694 /// Subview ops must be expanded before we reach this stage;
1695 /// if one survives, report a match failure.
1696struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1697 using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1698
1699 LogicalResult
1700 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1701 ConversionPatternRewriter &rewriter) const override {
1702 return rewriter.notifyMatchFailure(
1703 subViewOp, "subview operations should have been expanded beforehand");
1704 }
1705};
1706
1707 /// Conversion pattern that transforms a transpose op into:
1708 /// 1. A poison memref descriptor of the result type.
1709 /// 2. Copies of the allocated and aligned pointers and of the offset from
1710 /// the source descriptor.
1711 /// 3. Updates to the descriptor sizes and strides, permuted according to
1712 /// the permutation map.
1713 /// The transpose op is replaced by the new descriptor.
1714class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1715public:
1716 using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1717
1718 LogicalResult
1719 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1720 ConversionPatternRewriter &rewriter) const override {
1721 auto loc = transposeOp.getLoc();
1722 MemRefDescriptor viewMemRef(adaptor.getIn());
1723
1724 // No permutation, early exit.
1725 if (transposeOp.getPermutation().isIdentity())
1726 return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1727
1728 auto targetMemRef = MemRefDescriptor::poison(
1729 rewriter, loc,
1730 typeConverter->convertType(transposeOp.getIn().getType()));
1731
1732 // Copy the base and aligned pointers from the old descriptor to the new
1733 // one.
1734 targetMemRef.setAllocatedPtr(rewriter, loc,
1735 viewMemRef.allocatedPtr(rewriter, loc));
1736 targetMemRef.setAlignedPtr(rewriter, loc,
1737 viewMemRef.alignedPtr(rewriter, loc));
1738
1739 // Copy the offset pointer from the old descriptor to the new one.
1740 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1741
1742 // Iterate over the dimensions and apply size/stride permutation:
1743 // When enumerating the results of the permutation map, the enumeration
1744 // index is the index into the target dimensions and the DimExpr points to
1745 // the dimension of the source memref.
1746 for (const auto &en :
1747 llvm::enumerate(transposeOp.getPermutation().getResults())) {
1748 int targetPos = en.index();
1749 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1750 targetMemRef.setSize(rewriter, loc, targetPos,
1751 viewMemRef.size(rewriter, loc, sourcePos));
1752 targetMemRef.setStride(rewriter, loc, targetPos,
1753 viewMemRef.stride(rewriter, loc, sourcePos));
1754 }
1755
1756 rewriter.replaceOp(transposeOp, {targetMemRef});
1757 return success();
1758 }
1759};
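// Example of the permutation handling above: for the permutation map
// (d0, d1) -> (d1, d0) on a 2-D memref, the loop emits
//   newSize[0] = oldSize[1], newStride[0] = oldStride[1]
//   newSize[1] = oldSize[0], newStride[1] = oldStride[0]
// while the pointers and the offset are copied through unchanged; the
// underlying data is never moved.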
1760
1761 /// Conversion pattern that transforms a view op into:
1762 /// 1. An `llvm.mlir.poison` operation to create a memref descriptor
1763/// 2. Updates to the descriptor to introduce the data ptr, offset, size
1764/// and stride.
1765/// The view op is replaced by the descriptor.
1766struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1767 using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1768
1769 // Build and return the value for the idx^th shape dimension: either the
1770 // constant dimension or the proper entry of dynamicSizes.
1771 Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1772 ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1773 Type indexType) const {
1774 assert(idx < shape.size());
1775 if (ShapedType::isStatic(shape[idx]))
1776 return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
1777 // Count the number of dynamic dims in range [0, idx).
1778 unsigned nDynamic =
1779 llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1780 return dynamicSizes[nDynamic];
1781 }
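// A worked example for getSize, assuming shape [2, ?, 4, ?] with dynamicSizes
// holding the two '?' values in order: for idx = 3, shape[3] is dynamic and
// take_front(3) = [2, ?, 4] contains one dynamic dim, so nDynamic = 1 and
// dynamicSizes[1] is returned.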
1782
1783 // Build and return the idx^th stride, either by returning the constant stride
1784 // or by computing the dynamic stride from the current `runningStride` and
1785 // `nextSize`. The caller should keep a running stride and update it with the
1786 // result returned by this function.
1787 Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1788 ArrayRef<int64_t> strides, Value nextSize,
1789 Value runningStride, unsigned idx, Type indexType) const {
1790 assert(idx < strides.size());
1791 if (ShapedType::isStatic(strides[idx]))
1792 return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
1793 if (nextSize)
1794 return runningStride
1795 ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
1796 : nextSize;
1797 assert(!runningStride);
1798 return createIndexAttrConstant(rewriter, loc, indexType, 1);
1799 }
1800
1801 LogicalResult
1802 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1803 ConversionPatternRewriter &rewriter) const override {
1804 auto loc = viewOp.getLoc();
1805
1806 auto viewMemRefType = viewOp.getType();
1807 auto targetElementTy =
1808 typeConverter->convertType(viewMemRefType.getElementType());
1809 auto targetDescTy = typeConverter->convertType(viewMemRefType);
1810 if (!targetDescTy || !targetElementTy ||
1811 !LLVM::isCompatibleType(targetElementTy) ||
1812 !LLVM::isCompatibleType(targetDescTy))
1813 return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1814 failure();
1815
1816 int64_t offset;
1817 SmallVector<int64_t, 4> strides;
1818 auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
1819 if (failed(successStrides))
1820 return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1821 assert(offset == 0 && "expected offset to be 0");
1822
1823 // Target memref must be contiguous in memory (innermost stride is 1), or
1824 // empty (special case when at least one of the memref dimensions is 0).
1825 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1826 return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1827 failure();
1828
1829 // Create the descriptor.
1830 MemRefDescriptor sourceMemRef(adaptor.getSource());
1831 auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);
1832
1833 // Field 1: Copy the allocated pointer, used for malloc/free.
1834 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1835 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1836 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1837
1838 // Field 2: Copy the actual aligned pointer to payload.
1839 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1840 alignedPtr = LLVM::GEPOp::create(
1841 rewriter, loc, alignedPtr.getType(),
1842 typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1843 adaptor.getByteShift());
1844
1845 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1846
1847 Type indexType = getIndexType();
1848 // Field 3: The offset in the resulting type must be 0. This is
1849 // because of the type change: an offset on srcType* may not be
1850 // expressible as an offset on dstType*.
1851 targetMemRef.setOffset(
1852 rewriter, loc,
1853 createIndexAttrConstant(rewriter, loc, indexType, offset));
1854
1855 // Early exit for 0-D corner case.
1856 if (viewMemRefType.getRank() == 0)
1857 return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1858
1859 // Fields 4 and 5: Update sizes and strides.
1860 Value stride = nullptr, nextSize = nullptr;
1861 for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1862 // Update size.
1863 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1864 adaptor.getSizes(), i, indexType);
1865 targetMemRef.setSize(rewriter, loc, i, size);
1866 // Update stride.
1867 stride =
1868 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1869 targetMemRef.setStride(rewriter, loc, i, stride);
1870 nextSize = size;
1871 }
1872
1873 rewriter.replaceOp(viewOp, {targetMemRef});
1874 return success();
1875 }
1876};
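// A minimal use of this pattern, as a sketch (hypothetical IR):
//
//   %v = memref.view %buf[%byteShift][%n]
//       : memref<2048xi8> to memref<?x64xf32>
//
// Only the descriptor is rebuilt: the aligned pointer is advanced by
// %byteShift via the GEP above, the offset is reset to 0, and the
// sizes/strides are recomputed for the contiguous row-major result type.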
1877
1878//===----------------------------------------------------------------------===//
1879// AtomicRMWOpLowering
1880//===----------------------------------------------------------------------===//
1881
1882 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1883 /// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
1884static std::optional<LLVM::AtomicBinOp>
1885matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1886 switch (atomicOp.getKind()) {
1887 case arith::AtomicRMWKind::addf:
1888 return LLVM::AtomicBinOp::fadd;
1889 case arith::AtomicRMWKind::addi:
1890 return LLVM::AtomicBinOp::add;
1891 case arith::AtomicRMWKind::assign:
1892 return LLVM::AtomicBinOp::xchg;
1893 case arith::AtomicRMWKind::maximumf:
1894 // TODO: remove this by end of 2025.
1895 LDBG() << "the lowering of memref.atomicrmw maximumf changed "
1896 "from fmax to fmaximum, expect more NaNs";
1897 return LLVM::AtomicBinOp::fmaximum;
1898 case arith::AtomicRMWKind::maxnumf:
1899 return LLVM::AtomicBinOp::fmax;
1900 case arith::AtomicRMWKind::maxs:
1901 return LLVM::AtomicBinOp::max;
1902 case arith::AtomicRMWKind::maxu:
1903 return LLVM::AtomicBinOp::umax;
1904 case arith::AtomicRMWKind::minimumf:
1905 // TODO: remove this by end of 2025.
1906 LDBG() << "the lowering of memref.atomicrmw minimumf changed "
1907 "from fmin to fminimum, expect more NaNs";
1908 return LLVM::AtomicBinOp::fminimum;
1909 case arith::AtomicRMWKind::minnumf:
1910 return LLVM::AtomicBinOp::fmin;
1911 case arith::AtomicRMWKind::mins:
1912 return LLVM::AtomicBinOp::min;
1913 case arith::AtomicRMWKind::minu:
1914 return LLVM::AtomicBinOp::umin;
1915 case arith::AtomicRMWKind::ori:
1916 return LLVM::AtomicBinOp::_or;
1917 case arith::AtomicRMWKind::xori:
1918 return LLVM::AtomicBinOp::_xor;
1919 case arith::AtomicRMWKind::andi:
1920 return LLVM::AtomicBinOp::_and;
1921 default:
1922 return std::nullopt;
1923 }
1924 llvm_unreachable("Invalid AtomicRMWKind");
1925}
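// For example (sketch): `memref.atomic_rmw addf %val, %mem[%i]` maps to an
// `llvm.atomicrmw fadd` on the strided element pointer with acq_rel ordering
// (see the pattern below). Kinds for which this helper returns std::nullopt,
// e.g. mulf, are not handled here and must be lowered elsewhere, typically
// via a compare-and-swap loop.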
1926
1927struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1928 using Base::Base;
1929
1930 LogicalResult
1931 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1932 ConversionPatternRewriter &rewriter) const override {
1933 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1934 if (!maybeKind)
1935 return failure();
1936 auto memRefType = atomicOp.getMemRefType();
1937 SmallVector<int64_t> strides;
1938 int64_t offset;
1939 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1940 return failure();
1941 auto dataPtr =
1942 getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
1943 adaptor.getMemref(), adaptor.getIndices());
1944 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1945 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1946 LLVM::AtomicOrdering::acq_rel);
1947 return success();
1948 }
1949};
1950
1951 /// Lower memref.extract_aligned_pointer_as_index to a descriptor read plus ptrtoint.
1952class ConvertExtractAlignedPointerAsIndex
1953 : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1954public:
1955 using ConvertOpToLLVMPattern<
1956 memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1957
1958 LogicalResult
1959 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1960 OpAdaptor adaptor,
1961 ConversionPatternRewriter &rewriter) const override {
1962 BaseMemRefType sourceTy = extractOp.getSource().getType();
1963
1964 Value alignedPtr;
1965 if (sourceTy.hasRank()) {
1966 MemRefDescriptor desc(adaptor.getSource());
1967 alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1968 } else {
1969 auto elementPtrTy = LLVM::LLVMPointerType::get(
1970 rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1971
1972 UnrankedMemRefDescriptor desc(adaptor.getSource());
1973 Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1974
1975 alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1976 rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1977 elementPtrTy);
1978 }
1979
1980 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1981 extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1982 return success();
1983 }
1984};
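// Sketch of the lowering: `%i = memref.extract_aligned_pointer_as_index %m`
// becomes a read of the aligned-pointer field of %m's descriptor (directly
// for ranked memrefs, through the underlying descriptor pointer for unranked
// ones) followed by `llvm.ptrtoint` to the converter's index type.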
1985
1986 /// Materialize the results of memref.extract_strided_metadata (base buffer,
1987 /// offset, sizes, strides) by reading them out of the source descriptor.
1988class ExtractStridedMetadataOpLowering
1989 : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1990public:
1991 using ConvertOpToLLVMPattern<
1992 memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1993
1994 LogicalResult
1995 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1996 OpAdaptor adaptor,
1997 ConversionPatternRewriter &rewriter) const override {
1998
1999 if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
2000 return failure();
2001
2002 // Create the descriptor.
2003 MemRefDescriptor sourceMemRef(adaptor.getSource());
2004 Location loc = extractStridedMetadataOp.getLoc();
2005 Value source = extractStridedMetadataOp.getSource();
2006
2007 auto sourceMemRefType = cast<MemRefType>(source.getType());
2008 int64_t rank = sourceMemRefType.getRank();
2009 SmallVector<Value> results;
2010 results.reserve(2 + rank * 2);
2011
2012 // Base buffer.
2013 Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
2014 Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
2015 MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
2016 rewriter, loc, *getTypeConverter(),
2017 cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
2018 baseBuffer, alignedBuffer);
2019 results.push_back((Value)dstMemRef);
2020
2021 // Offset.
2022 results.push_back(sourceMemRef.offset(rewriter, loc));
2023
2024 // Sizes.
2025 for (unsigned i = 0; i < rank; ++i)
2026 results.push_back(sourceMemRef.size(rewriter, loc, i));
2027 // Strides.
2028 for (unsigned i = 0; i < rank; ++i)
2029 results.push_back(sourceMemRef.stride(rewriter, loc, i));
2030
2031 rewriter.replaceOp(extractStridedMetadataOp, results);
2032 return success();
2033 }
2034};
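// Sketch of the result order produced above for %m : memref<?x?xf32>:
//   results[0]    base buffer, rebuilt as a 0-d memref descriptor that
//                 carries both the allocated and the aligned pointer
//   results[1]    offset
//   results[2..3] sizes, results[4..5] strides
// all read straight out of %m's converted descriptor; no arithmetic is
// emitted beyond the descriptor extractions.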
2035
2036} // namespace
2037
2038 void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
2039 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2040 SymbolTableCollection *symbolTables) {
2041 // clang-format off
2042 patterns.add<
2043 AllocaOpLowering,
2044 AllocaScopeOpLowering,
2045 AssumeAlignmentOpLowering,
2046 AtomicRMWOpLowering,
2047 ConvertExtractAlignedPointerAsIndex,
2048 DimOpLowering,
2049 DistinctObjectsOpLowering,
2050 ExtractStridedMetadataOpLowering,
2051 GenericAtomicRMWOpLowering,
2052 GetGlobalMemrefOpLowering,
2053 LoadOpLowering,
2054 MemRefCastOpLowering,
2055 MemRefReinterpretCastOpLowering,
2056 MemRefReshapeOpLowering,
2057 MemorySpaceCastOpLowering,
2058 PrefetchOpLowering,
2059 RankOpLowering,
2060 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2061 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
2062 StoreOpLowering,
2063 SubViewOpLowering,
2064 TransposeOpLowering,
2065 ViewOpLowering>(converter);
2066 // clang-format on
2067 patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
2068 symbolTables);
2069 auto allocLowering = converter.getOptions().allocLowering;
2070 if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
2071 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
2072 symbolTables);
2073 else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
2074 patterns.add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
2075}
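// A typical standalone use of this entry point, as a sketch (conversion
// target setup elided; see the pass below for the full wiring):
//
//   LowerToLLVMOptions options(ctx);
//   options.allocLowering = LowerToLLVMOptions::AllocLowering::AlignedAlloc;
//   LLVMTypeConverter typeConverter(ctx, options);
//   RewritePatternSet patterns(ctx);
//   populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);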
2076
2077namespace {
2078struct FinalizeMemRefToLLVMConversionPass
2079 : public impl::FinalizeMemRefToLLVMConversionPassBase<
2080 FinalizeMemRefToLLVMConversionPass> {
2081 using FinalizeMemRefToLLVMConversionPassBase::
2082 FinalizeMemRefToLLVMConversionPassBase;
2083
2084 void runOnOperation() override {
2085 Operation *op = getOperation();
2086 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2087 LowerToLLVMOptions options(&getContext(),
2088 dataLayoutAnalysis.getAtOrAbove(op));
2089 options.allocLowering =
2090 (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2091 : LowerToLLVMOptions::AllocLowering::Malloc);
2092
2093 options.useGenericFunctions = useGenericFunctions;
2094
2095 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2096 options.overrideIndexBitwidth(indexBitwidth);
2097
2098 LLVMTypeConverter typeConverter(&getContext(), options,
2099 &dataLayoutAnalysis);
2100 RewritePatternSet patterns(&getContext());
2101 SymbolTableCollection symbolTables;
2102 populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns,
2103 &symbolTables);
2104 LLVMConversionTarget target(getContext());
2105 target.addLegalOp<func::FuncOp>();
2106 if (failed(applyPartialConversion(op, target, std::move(patterns))))
2107 signalPassFailure();
2108 }
2109};
2110
2111/// Implement the interface to convert MemRef to LLVM.
2112struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
2113 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
2114 void loadDependentDialects(MLIRContext *context) const final {
2115 context->loadDialect<LLVM::LLVMDialect>();
2116 }
2117
2118 /// Hook for derived dialect interface to provide conversion patterns
2119 /// and mark dialect legal for the conversion target.
2120 void populateConvertToLLVMConversionPatterns(
2121 ConversionTarget &target, LLVMTypeConverter &typeConverter,
2122 RewritePatternSet &patterns) const final {
2123 populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
2124 }
2125};
2126
2127} // namespace
2128
2129 void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
2130 registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
2131 dialect->addInterfaces<MemRefToLLVMDialectInterface>();
2132 });
2133}
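// Registration sketch for tools that drive the pluggable convert-to-llvm
// interface (driver code below is hypothetical):
//
//   DialectRegistry registry;
//   registry.insert<memref::MemRefDialect>();
//   registerConvertMemRefToLLVMInterface(registry);
//   MLIRContext context(registry);
//
// Once registered, the generic conversion pass discovers the interface and
// pulls in these patterns when it encounters the memref dialect.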