MLIR 23.0.0git
EmulateNarrowType.cpp
Go to the documentation of this file.
1//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++
2//-*-===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
19#include "mlir/IR/Builders.h"
23#include "llvm/Support/FormatVariadic.h"
24#include "llvm/Support/MathExtras.h"
25#include <cassert>
26#include <type_traits>
27
28using namespace mlir;
29
30//===----------------------------------------------------------------------===//
31// Utility functions
32//===----------------------------------------------------------------------===//
33
34/// Converts a memref::ReinterpretCastOp to the converted type. The result
35/// MemRefType of the old op must have a rank and stride of 1, with static
36/// offset and size. The number of bits in the offset must evenly divide the
37/// bitwidth of the new converted type.
38static LogicalResult
39convertCastingOp(ConversionPatternRewriter &rewriter,
40 memref::ReinterpretCastOp::Adaptor adaptor,
41 memref::ReinterpretCastOp op, MemRefType newTy) {
42 auto convertedElementType = newTy.getElementType();
43 auto oldElementType = op.getType().getElementType();
44 int srcBits = oldElementType.getIntOrFloatBitWidth();
45 int dstBits = convertedElementType.getIntOrFloatBitWidth();
46 if (dstBits % srcBits != 0) {
47 return rewriter.notifyMatchFailure(op,
48 "only dstBits % srcBits == 0 supported");
49 }
50
51 // Only support stride of 1.
52 if (llvm::any_of(op.getStaticStrides(),
53 [](int64_t stride) { return stride != 1; })) {
54 return rewriter.notifyMatchFailure(op->getLoc(),
55 "stride != 1 is not supported");
56 }
57
58 auto sizes = op.getStaticSizes();
59 int64_t offset = op.getStaticOffset(0);
60 // Only support static sizes and offsets.
61 if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
62 offset == ShapedType::kDynamic) {
63 return rewriter.notifyMatchFailure(
64 op, "dynamic size or offset is not supported");
65 }
66
67 int elementsPerByte = dstBits / srcBits;
68 if (offset % elementsPerByte != 0) {
69 return rewriter.notifyMatchFailure(
70 op, "offset not multiple of elementsPerByte is not supported");
71 }
72
74 if (!sizes.empty())
75 size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
76 offset = offset / elementsPerByte;
77
78 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
79 op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
80 return success();
81}
82
83/// When data is loaded/stored in `targetBits` granularity, but is used in
84/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
85/// treated as an array of elements of width `sourceBits`.
86/// Return the bit offset of the value at position `srcIdx`. For example, if
87/// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
88/// located at (x % 2) * 4. Because there are two elements in one i8, and one
89/// element has 4 bits.
91 int sourceBits, int targetBits,
92 OpBuilder &builder) {
93 assert(targetBits % sourceBits == 0);
94 AffineExpr s0;
95 bindSymbols(builder.getContext(), s0);
96 int scaleFactor = targetBits / sourceBits;
97 AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
98 OpFoldResult offsetVal =
99 affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
100 Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
101 IntegerType dstType = builder.getIntegerType(targetBits);
102 return arith::IndexCastOp::create(builder, loc, dstType, bitOffset);
103}
104
105/// When writing a subbyte size, masked bitwise operations are used to only
106/// modify the relevant bits. This function returns an and mask for clearing
107/// the destination bits in a subbyte write. E.g., when writing to the second
108/// i4 in an i32, 0xFFFFFF0F is created.
109static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
110 int64_t srcBits, int64_t dstBits,
111 Value bitwidthOffset, OpBuilder &builder) {
112 auto dstIntegerType = builder.getIntegerType(dstBits);
113 auto maskRightAlignedAttr =
114 builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
115 Value maskRightAligned = arith::ConstantOp::create(
116 builder, loc, dstIntegerType, maskRightAlignedAttr);
117 Value writeMaskInverse =
118 arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset);
119 auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
120 Value flipVal =
121 arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr);
122 return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal);
123}
124
125/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
126/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
127/// the returned index has the granularity of `dstBits`
129 OpFoldResult linearizedIndex,
130 int64_t srcBits, int64_t dstBits) {
131 AffineExpr s0;
132 bindSymbols(builder.getContext(), s0);
133 int64_t scaler = dstBits / srcBits;
134 OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
135 builder, loc, s0.floorDiv(scaler), {linearizedIndex});
136 return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
137}
138
139static OpFoldResult
142 Value memref) {
143 auto stridedMetadata =
144 memref::ExtractStridedMetadataOp::create(builder, loc, memref);
145 OpFoldResult linearizedIndices;
146 std::tie(std::ignore, linearizedIndices) =
148 builder, loc, srcBits, srcBits,
149 stridedMetadata.getConstifiedMixedOffset(),
150 stridedMetadata.getConstifiedMixedSizes(),
151 stridedMetadata.getConstifiedMixedStrides(), indices);
152 return linearizedIndices;
153}
154
155namespace {
156
157//===----------------------------------------------------------------------===//
158// ConvertMemRefAllocation
159//===----------------------------------------------------------------------===//
160
161template <typename OpTy>
162struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
163 using OpConversionPattern<OpTy>::OpConversionPattern;
164
165 LogicalResult
166 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
167 ConversionPatternRewriter &rewriter) const override {
168 static_assert(std::is_same<OpTy, memref::AllocOp>() ||
169 std::is_same<OpTy, memref::AllocaOp>(),
170 "expected only memref::AllocOp or memref::AllocaOp");
171 auto currentType = cast<MemRefType>(op.getMemref().getType());
172 auto newResultType =
173 this->getTypeConverter()->template convertType<MemRefType>(
174 op.getType());
175 if (!newResultType) {
176 return rewriter.notifyMatchFailure(
177 op->getLoc(),
178 llvm::formatv("failed to convert memref type: {0}", op.getType()));
179 }
180
181 // Special case zero-rank memrefs.
182 if (currentType.getRank() == 0) {
183 rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
184 adaptor.getSymbolOperands(),
185 adaptor.getAlignmentAttr());
186 return success();
187 }
188
189 Location loc = op.getLoc();
190 OpFoldResult zero = rewriter.getIndexAttr(0);
191
192 // Get linearized type.
193 int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
194 int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
195 SmallVector<OpFoldResult> sizes = op.getMixedSizes();
196
197 memref::LinearizedMemRefInfo linearizedMemRefInfo =
199 rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
200 SmallVector<Value> dynamicLinearizedSize;
201 if (!newResultType.hasStaticShape()) {
202 dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
203 rewriter, loc, linearizedMemRefInfo.linearizedSize));
204 }
205
206 rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
207 adaptor.getSymbolOperands(),
208 adaptor.getAlignmentAttr());
209 return success();
210 }
211};
212
213//===----------------------------------------------------------------------===//
214// ConvertMemRefAssumeAlignment
215//===----------------------------------------------------------------------===//
216
217struct ConvertMemRefAssumeAlignment final
218 : OpConversionPattern<memref::AssumeAlignmentOp> {
219 using OpConversionPattern::OpConversionPattern;
220
221 LogicalResult
222 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
223 ConversionPatternRewriter &rewriter) const override {
224 Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
225 if (!newTy) {
226 return rewriter.notifyMatchFailure(
227 op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
228 op.getMemref().getType()));
229 }
230
231 rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
232 op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
233 return success();
234 }
235};
236
237//===----------------------------------------------------------------------===//
238// ConvertMemRefCopy
239//===----------------------------------------------------------------------===//
240
241struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
242 using OpConversionPattern::OpConversionPattern;
243
244 LogicalResult
245 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
246 ConversionPatternRewriter &rewriter) const override {
247 auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
248 auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
249 if (maybeRankedSource && maybeRankedDest &&
250 maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
251 return rewriter.notifyMatchFailure(
252 op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
253 "and {1}) is currently unimplemented",
254 maybeRankedSource.getLayout(),
255 maybeRankedDest.getLayout()));
256 rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
257 adaptor.getTarget());
258 return success();
259 }
260};
261
262//===----------------------------------------------------------------------===//
263// ConvertMemRefDealloc
264//===----------------------------------------------------------------------===//
265
266struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
267 using OpConversionPattern::OpConversionPattern;
268
269 LogicalResult
270 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
271 ConversionPatternRewriter &rewriter) const override {
272 rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
273 return success();
274 }
275};
276
277//===----------------------------------------------------------------------===//
278// ConvertMemRefLoad
279//===----------------------------------------------------------------------===//
280
/// Emulates a narrow-type memref.load: loads the enclosing wide element and
/// shifts/masks the requested narrow value out of it.
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    // The wide element must hold a whole number of narrow elements.
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      // Zero-rank: the single wide element is loaded directly; the value is
      // assumed to sit in the low bits (no shift needed).
      bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
                                        ValueRange{});
    } else {
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet. This will be accounted for later.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      // Load the wide element that contains the requested narrow value
      // (index scaled down by dstBits/srcBits).
      Value newLoad = memref::LoadOp::create(
          rewriter, loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Get the offset and shift the bits to the rightmost.
      // Note, currently only the big-endian is supported.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset);
    }

    // Get the corresponding bits. If the arith computation bitwidth equals
    // to the emulated bitwidth, we apply a mask to extract the low bits.
    // It is not clear if this case actually happens in practice, but we keep
    // the operations just in case. Otherwise, if the arith computation bitwidth
    // is different from the emulated bitwidth we truncate the result.
    Value result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    // Non-integer (e.g. narrow float) results are computed on a same-width
    // integer and bitcast back to the real result type at the end.
    auto conversionTy =
        resultTy.isInteger()
            ? resultTy
            : IntegerType::get(rewriter.getContext(),
                               resultTy.getIntOrFloatBitWidth());
    if (conversionTy == convertedElementType) {
      // Mask of srcBits ones extracts the low narrow value.
      auto mask = arith::ConstantOp::create(
          rewriter, loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask);
    } else {
      result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad);
    }

    if (conversionTy != resultTy) {
      result = arith::BitcastOp::create(rewriter, loc, resultTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
351
352//===----------------------------------------------------------------------===//
353// ConvertMemRefMemorySpaceCast
354//===----------------------------------------------------------------------===//
355
356struct ConvertMemRefMemorySpaceCast final
357 : OpConversionPattern<memref::MemorySpaceCastOp> {
358 using OpConversionPattern::OpConversionPattern;
359
360 LogicalResult
361 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
362 ConversionPatternRewriter &rewriter) const override {
363 Type newTy = getTypeConverter()->convertType(op.getDest().getType());
364 if (!newTy) {
365 return rewriter.notifyMatchFailure(
366 op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
367 op.getDest().getType()));
368 }
369
370 rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
371 adaptor.getSource());
372 return success();
373 }
374};
375
376//===----------------------------------------------------------------------===//
377// ConvertMemRefReinterpretCast
378//===----------------------------------------------------------------------===//
379
380/// Output types should be at most one dimensional, so only the 0 or 1
381/// dimensional cases are supported.
382struct ConvertMemRefReinterpretCast final
383 : OpConversionPattern<memref::ReinterpretCastOp> {
384 using OpConversionPattern::OpConversionPattern;
385
386 LogicalResult
387 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
388 ConversionPatternRewriter &rewriter) const override {
389 MemRefType newTy =
390 getTypeConverter()->convertType<MemRefType>(op.getType());
391 if (!newTy) {
392 return rewriter.notifyMatchFailure(
393 op->getLoc(),
394 llvm::formatv("failed to convert memref type: {0}", op.getType()));
395 }
396
397 // Only support for 0 or 1 dimensional cases.
398 if (op.getType().getRank() > 1) {
399 return rewriter.notifyMatchFailure(
400 op->getLoc(), "subview with rank > 1 is not supported");
401 }
402
403 return convertCastingOp(rewriter, adaptor, op, newTy);
404 }
405};
406
407//===----------------------------------------------------------------------===//
408// ConvertMemrefStore
409//===----------------------------------------------------------------------===//
410
411/// Emulate narrow type memref store with a non-atomic or atomic
412/// read-modify-write sequence. The `disableAtomicRMW` indicates whether to use
413/// a normal read-modify-write sequence instead of using
414/// `memref.generic_atomic_rmw` to perform subbyte storing.
415struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
416 using OpConversionPattern::OpConversionPattern;
417
418 ConvertMemrefStore(MLIRContext *context, bool disableAtomicRMW)
419 : OpConversionPattern<memref::StoreOp>(context),
420 disableAtomicRMW(disableAtomicRMW) {}
421
422 LogicalResult
423 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
424 ConversionPatternRewriter &rewriter) const override {
425 auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
426 int srcBits = op.getMemRefType().getElementTypeBitWidth();
427 int dstBits = convertedType.getElementTypeBitWidth();
428 auto dstIntegerType = rewriter.getIntegerType(dstBits);
429 if (dstBits % srcBits != 0) {
430 return rewriter.notifyMatchFailure(
431 op, "only dstBits % srcBits == 0 supported");
432 }
433
434 Location loc = op.getLoc();
435
436 // Pad the input value with 0s on the left.
437 Value input = adaptor.getValue();
438 if (!input.getType().isInteger()) {
439 input = arith::BitcastOp::create(
440 rewriter, loc,
441 IntegerType::get(rewriter.getContext(),
443 input);
444 }
445 Value extendedInput =
446 arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);
447
448 // Special case 0-rank memref stores. No need for masking. The non-atomic
449 // store is used because it operates on the entire value.
450 if (convertedType.getRank() == 0) {
451 memref::StoreOp::create(rewriter, loc, extendedInput, adaptor.getMemref(),
452 ValueRange{});
453 rewriter.eraseOp(op);
454 return success();
455 }
456
457 OpFoldResult linearizedIndices = getLinearizedSrcIndices(
458 rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
459 Value storeIndices = getIndicesForLoadOrStore(
460 rewriter, loc, linearizedIndices, srcBits, dstBits);
461 Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
462 dstBits, rewriter);
463 Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
464 dstBits, bitwidthOffset, rewriter);
465 // Align the value to write with the destination bits.
466 Value alignedVal =
467 arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);
468
469 if (disableAtomicRMW) {
470 // Load the original value.
471 Value origValue = memref::LoadOp::create(
472 rewriter, loc, adaptor.getMemref(), storeIndices);
473 // Clear destination bits (and with mask).
474 Value clearedValue =
475 arith::AndIOp::create(rewriter, loc, origValue, writeMask);
476 // Write src bits to destination (or with aligned value), and store the
477 // result.
478 Value newValue =
479 arith::OrIOp::create(rewriter, loc, clearedValue, alignedVal);
480 memref::StoreOp::create(rewriter, loc, newValue, adaptor.getMemref(),
481 storeIndices);
482 } else {
483 // Atomic read-modify-write operations.
484 // Clear destination bits.
485 memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
486 writeMask, adaptor.getMemref(), storeIndices);
487 // Write src bits to destination.
488 memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
489 alignedVal, adaptor.getMemref(),
490 storeIndices);
491 }
492 rewriter.eraseOp(op);
493 return success();
494 }
495
496private:
497 bool disableAtomicRMW;
498};
499
500//===----------------------------------------------------------------------===//
501// ConvertMemRefSubview
502//===----------------------------------------------------------------------===//
503
504/// Emulating narrow ints on subview have limited support, supporting only
505/// static offset and size and stride of 1. Ideally, the subview should be
506/// folded away before running narrow type emulation, and this pattern should
507/// only run for cases that can't be folded.
508struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
509 using OpConversionPattern::OpConversionPattern;
510
511 LogicalResult
512 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
513 ConversionPatternRewriter &rewriter) const override {
514 MemRefType newTy =
515 getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
516 if (!newTy) {
517 return rewriter.notifyMatchFailure(
518 subViewOp->getLoc(),
519 llvm::formatv("failed to convert memref type: {0}",
520 subViewOp.getType()));
521 }
522
523 Location loc = subViewOp.getLoc();
524 Type convertedElementType = newTy.getElementType();
525 Type oldElementType = subViewOp.getType().getElementType();
526 int srcBits = oldElementType.getIntOrFloatBitWidth();
527 int dstBits = convertedElementType.getIntOrFloatBitWidth();
528 if (dstBits % srcBits != 0)
529 return rewriter.notifyMatchFailure(
530 subViewOp, "only dstBits % srcBits == 0 supported");
531
532 // Only support stride of 1.
533 if (llvm::any_of(subViewOp.getStaticStrides(),
534 [](int64_t stride) { return stride != 1; })) {
535 return rewriter.notifyMatchFailure(subViewOp->getLoc(),
536 "stride != 1 is not supported");
537 }
538
539 if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
540 return rewriter.notifyMatchFailure(
541 subViewOp, "the result memref type is not contiguous");
542 }
543
544 auto sizes = subViewOp.getStaticSizes();
545 int64_t lastOffset = subViewOp.getStaticOffsets().back();
546 // Only support static sizes and offsets.
547 if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
548 lastOffset == ShapedType::kDynamic) {
549 return rewriter.notifyMatchFailure(
550 subViewOp->getLoc(), "dynamic size or offset is not supported");
551 }
552
553 // Transform the offsets, sizes and strides according to the emulation.
554 auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
555 rewriter, loc, subViewOp.getViewSource());
556
557 OpFoldResult linearizedIndices;
558 auto strides = stridedMetadata.getConstifiedMixedStrides();
559 memref::LinearizedMemRefInfo linearizedInfo;
560 std::tie(linearizedInfo, linearizedIndices) =
562 rewriter, loc, srcBits, dstBits,
563 stridedMetadata.getConstifiedMixedOffset(),
564 subViewOp.getMixedSizes(), strides,
565 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
566 rewriter));
567
568 rewriter.replaceOpWithNewOp<memref::SubViewOp>(
569 subViewOp, newTy, adaptor.getSource(), linearizedIndices,
570 linearizedInfo.linearizedSize, strides.back());
571 return success();
572 }
573};
574
575//===----------------------------------------------------------------------===//
576// ConvertMemRefCollapseShape
577//===----------------------------------------------------------------------===//
578
579/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
580/// that we flatten memrefs to a single dimension as part of the emulation and
581/// there is no dimension to collapse any further.
582struct ConvertMemRefCollapseShape final
583 : OpConversionPattern<memref::CollapseShapeOp> {
584 using OpConversionPattern::OpConversionPattern;
585
586 LogicalResult
587 matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
588 ConversionPatternRewriter &rewriter) const override {
589 Value srcVal = adaptor.getSrc();
590 auto newTy = dyn_cast<MemRefType>(srcVal.getType());
591 if (!newTy)
592 return failure();
593
594 if (newTy.getRank() != 1)
595 return failure();
596
597 rewriter.replaceOp(collapseShapeOp, srcVal);
598 return success();
599 }
600};
601
602/// Emulating a `memref.expand_shape` becomes a no-op after emulation given
603/// that we flatten memrefs to a single dimension as part of the emulation and
604/// the expansion would just have been undone.
605struct ConvertMemRefExpandShape final
606 : OpConversionPattern<memref::ExpandShapeOp> {
607 using OpConversionPattern::OpConversionPattern;
608
609 LogicalResult
610 matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
611 ConversionPatternRewriter &rewriter) const override {
612 Value srcVal = adaptor.getSrc();
613 auto newTy = dyn_cast<MemRefType>(srcVal.getType());
614 if (!newTy)
615 return failure();
616
617 if (newTy.getRank() != 1)
618 return failure();
619
620 rewriter.replaceOp(expandShapeOp, srcVal);
621 return success();
622 }
623};
624} // end anonymous namespace
625
626//===----------------------------------------------------------------------===//
627// Public Interface Definition
628//===----------------------------------------------------------------------===//
629
631 const arith::NarrowTypeEmulationConverter &typeConverter,
632 RewritePatternSet &patterns, bool disableAtomicRMW) {
633
634 // Populate `memref.*` conversion patterns.
635 patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
636 ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
637 ConvertMemRefDealloc, ConvertMemRefCollapseShape,
638 ConvertMemRefExpandShape, ConvertMemRefLoad,
639 ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
640 ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
641 typeConverter, patterns.getContext());
642 patterns.insert<ConvertMemrefStore>(patterns.getContext(), disableAtomicRMW);
644}
645
646static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
647 int dstBits) {
648 if (ty.getRank() == 0)
649 return {};
650
651 int64_t linearizedShape = 1;
652 for (auto shape : ty.getShape()) {
653 if (shape == ShapedType::kDynamic)
654 return {ShapedType::kDynamic};
655 linearizedShape *= shape;
656 }
657 int scale = dstBits / srcBits;
658 // Scale the size to the ceilDiv(linearizedShape, scale)
659 // to accomodate all the values.
660 linearizedShape = (linearizedShape + scale - 1) / scale;
661 return {linearizedShape};
662}
663
666 typeConverter.addConversion(
667 [&typeConverter](MemRefType ty) -> std::optional<Type> {
668 Type elementType = ty.getElementType();
669 if (!elementType.isIntOrFloat())
670 return ty;
671
672 unsigned width = elementType.getIntOrFloatBitWidth();
673 unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
674 if (width >= loadStoreWidth)
675 return ty;
676
677 // Currently only handle innermost stride being 1, checking
678 SmallVector<int64_t> strides;
679 int64_t offset;
680 if (failed(ty.getStridesAndOffset(strides, offset)))
681 return nullptr;
682 if (!strides.empty() && strides.back() != 1)
683 return nullptr;
684
685 auto newElemTy = IntegerType::get(
686 ty.getContext(), loadStoreWidth,
687 elementType.isInteger()
688 ? cast<IntegerType>(elementType).getSignedness()
689 : IntegerType::SignednessSemantics::Signless);
690 if (!newElemTy)
691 return nullptr;
692
693 StridedLayoutAttr layoutAttr;
694 // If the offset is 0, we do not need a strided layout as the stride is
695 // 1, so we only use the strided layout if the offset is not 0.
696 if (offset != 0) {
697 if (offset == ShapedType::kDynamic) {
698 layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
700 } else {
701 // Check if the number of bytes are a multiple of the loadStoreWidth
702 // and if so, divide it by the loadStoreWidth to get the offset.
703 if ((offset * width) % loadStoreWidth != 0)
704 return std::nullopt;
705 offset = (offset * width) / loadStoreWidth;
706
707 layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
709 }
710 }
711
712 return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
713 newElemTy, layoutAttr, ty.getMemorySpace());
714 });
715}
return success()
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
When data is loaded/stored in targetBits granularity, but is used in sourceBits granularity (sourceBits < targetBits), the targetBits is treated as an array of elements of width sourceBits.
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc, OpFoldResult linearizedIndex, int64_t srcBits, int64_t dstBits)
Returns the scaled linearized index based on the srcBits and dstBits sizes.
static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter, memref::ReinterpretCastOp::Adaptor adaptor, memref::ReinterpretCastOp op, MemRefType newTy)
Converts a memref::ReinterpretCastOp to the converted type.
static OpFoldResult getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits, const SmallVector< OpFoldResult > &indices, Value memref)
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices, int64_t srcBits, int64_t dstBits, Value bitwidthOffset, OpBuilder &builder)
When writing a subbyte size, masked bitwise operations are used to only modify the relevant bits.
static SmallVector< int64_t > getLinearizedShape(MemRefType ty, int srcBits, int dstBits)
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Converts narrow integer or float types that are not supported by the target hardware to wider types.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateMemRefNarrowTypeEmulationPatterns(const arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns, bool disableAtomicRMW=false)
Appends patterns for emulating memref operations over narrow types with ops over wider types.
void populateMemRefNarrowTypeEmulationConversions(arith::NarrowTypeEmulationConverter &typeConverter)
Appends type conversions for emulating memref operations over narrow types with ops over wider types.
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112