//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// memref is linearized to a rank-1 byte view (or rank-0 if the source is
/// rank-0). When `assumeAligned` is true, dynamic offsets are accepted under
/// the alignment contract that the caller guarantees the offset is a multiple
/// of `dstBits / srcBits`; statically-provable misalignment is rejected.
/// When `assumeAligned` is false, dynamic offsets are rejected outright since
/// divisibility cannot be proven from the IR alone.
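///
/// Illustrative sketch (values chosen here, not taken from a test suite):
/// with srcBits = 4 and dstBits = 8 (two i4 elements per byte), a cast
/// yielding eight i4 values at static offset 2 is rewritten to yield four
/// i8 values at offset 1:
///
///   memref.reinterpret_cast %src to offset: [2], sizes: [8], strides: [1]
///   // becomes, over the converted i8 memref:
///   memref.reinterpret_cast %src to offset: [1], sizes: [4], strides: [1]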
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
                 memref::ReinterpretCastOp::Adaptor adaptor,
                 memref::ReinterpretCastOp op, MemRefType newTy,
                 bool assumeAligned) {
  if (newTy == op.getType()) {
    return rewriter.notifyMatchFailure(
        op, "result type was not converted by narrow-type emulation");
  }

  Type convertedElementType = newTy.getElementType();
  Type oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  ArrayRef<int64_t> staticStrides = op.getStaticStrides();
  if (!staticStrides.empty() && staticStrides.back() != 1) {
    return rewriter.notifyMatchFailure(
        op->getLoc(), "innermost stride != 1 is not supported");
  }

  // TODO: support dynamic sizes. Requires a divisibility analysis or a
  // stronger alignment contract; tracked as follow-up work.
  if (llvm::is_contained(op.getStaticSizes(), ShapedType::kDynamic)) {
    return rewriter.notifyMatchFailure(op, "dynamic sizes are not supported");
  }

  if (!memref::isStaticShapeAndContiguousRowMajor(op.getType())) {
    return rewriter.notifyMatchFailure(
        op, "result memref is not row-major contiguous");
  }

  // Reject dynamic offsets unless the caller has opted into the alignment
  // contract via `assumeAligned`. Without it we cannot prove the offset is a
  // multiple of `dstBits / srcBits`.
  if (!assumeAligned &&
      llvm::is_contained(op.getStaticOffsets(), ShapedType::kDynamic)) {
    return rewriter.notifyMatchFailure(
        op, "dynamic offsets require assumeAligned=true to ensure the offset "
            "is a multiple of dstBits / srcBits");
  }

  Location loc = op.getLoc();
  SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
  OpFoldResult origOffset = op.getMixedOffsets()[0];

  SmallVector<OpFoldResult> newSizes;
  SmallVector<OpFoldResult> newStrides;
  OpFoldResult newOffset;
  OpFoldResult intraOffset;
  if (mixedSizes.empty()) {
    int64_t elementsPerByte = dstBits / srcBits;
    AffineExpr s0;
    bindSymbols(rewriter.getContext(), s0);
    newOffset = affine::makeComposedFoldedAffineApply(
        rewriter, loc, s0.floorDiv(elementsPerByte), {origOffset});
    intraOffset = affine::makeComposedFoldedAffineApply(
        rewriter, loc, s0 % elementsPerByte, {origOffset});
  } else {
    // Use ceil division so the produced linearized size matches the converted
    // result memref shape (see `getLinearizedShape` in the type converter),
    // which also rounds up to fit all source elements.
    memref::LinearizedMemRefInfo info;
    std::tie(info, std::ignore) = memref::getLinearizedMemRefOffsetAndSize(
        rewriter, loc, srcBits, dstBits, origOffset, mixedSizes,
        op.getMixedStrides(), /*indices=*/{}, memref::LinearizedDivKind::Ceil);
    newOffset = info.linearizedOffset;
    intraOffset = info.intraDataOffset;
    newSizes.push_back(info.linearizedSize);
    newStrides.push_back(rewriter.getIndexAttr(1));
  }

  if (auto cst = getConstantIntValue(intraOffset); cst && *cst != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset is provably not a multiple of dstBits / srcBits");
  }

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), newOffset, newSizes, newStrides);
  return success();
}

/// When data is loaded/stored in `targetBits` granularity but used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits`
/// value is treated as an array of elements of width `sourceBits`.
/// Returns the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
/// located at (x % 2) * 4, because one i8 holds two 4-bit elements.
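///
/// A short worked example with sourceBits = 4, targetBits = 8
/// (scaleFactor = 2):
///   srcIdx = 0 -> (0 % 2) * 4 = 0  (low nibble)
///   srcIdx = 1 -> (1 % 2) * 4 = 4  (high nibble)
///   srcIdx = 2 -> (2 % 2) * 4 = 0  (low nibble of the next byte)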
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return arith::IndexCastOp::create(builder, loc, dstType, bitOffset);
}

/// When writing a subbyte size, masked bitwise operations are used to only
/// modify the relevant bits. This function returns an AND mask for clearing
/// the destination bits in a subbyte write. E.g., when writing to the second
/// i4 in an i32, 0xFFFFFF0F is created.
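///
/// Spelled out for that example: the mask is computed as
/// ~((2^srcBits - 1) << bitwidthOffset), i.e. ~(0xF << 4) = 0xFFFFFF0F,
/// with the negation implemented below as XOR against an all-ones constant.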
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = arith::ConstantOp::create(
      builder, loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr);
  return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal);
}

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
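/// E.g., with srcBits = 4 and dstBits = 8, the i4 element at linearized
/// index 5 lives in byte 5 floordiv 2 = 2 of the converted memref.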
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}

static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      memref::ExtractStridedMetadataOp::create(builder, loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType =
        this->getTypeConverter()->template convertType<MemRefType>(
            op.getType());
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);

    // Get linearized type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefInfo(
            rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCopy
//===----------------------------------------------------------------------===//

struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
    if (maybeRankedSource && maybeRankedDest &&
        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
                            "and {1}) is currently unimplemented",
                            maybeRankedSource.getLayout(),
                            maybeRankedDest.getLayout()));
    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
                                                adaptor.getTarget());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefDealloc
//===----------------------------------------------------------------------===//

struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

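/// Emulates a narrow-type load by loading the enclosing wide element,
/// shifting the requested bits down to the LSBs, and masking or truncating.
/// Illustrative walk-through (values chosen here): loading element 3 of a
/// memref<8xi4> converted to memref<4xi8> loads byte 3 floordiv 2 = 1,
/// shifts right by (3 % 2) * 4 = 4, and extracts the low nibble.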
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
                                        ValueRange{});
    } else {
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet. This will be accounted for later.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = memref::LoadOp::create(
          rewriter, loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Get the offset and shift the bits to the rightmost.
      // Note: currently only big-endian is supported.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset);
    }

    // Get the corresponding bits. If the arith computation bitwidth equals
    // the emulated bitwidth, we apply a mask to extract the low bits. It is
    // not clear if this case actually happens in practice, but we keep the
    // operations just in case. Otherwise, if the arith computation bitwidth
    // differs from the emulated bitwidth, we truncate the result.
    Value result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    auto conversionTy =
        resultTy.isInteger()
            ? resultTy
            : IntegerType::get(rewriter.getContext(),
                               resultTy.getIntOrFloatBitWidth());
    if (conversionTy == convertedElementType) {
      auto mask = arith::ConstantOp::create(
          rewriter, loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask);
    } else {
      result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad);
    }

    if (conversionTy != resultTy) {
      result = arith::BitcastOp::create(rewriter, loc, resultTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCast
//===----------------------------------------------------------------------===//

/// `memref.cast` between two narrow-typed memrefs forwards through the type
/// converter to a cast between the converted byte-typed memrefs.
struct ConvertMemRefCast final : OpConversionPattern<memref::CastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }
    if (newTy == op.getType())
      return failure();

    rewriter.replaceOpWithNewOp<memref::CastOp>(op, newTy, adaptor.getSource());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefMemorySpaceCast
//===----------------------------------------------------------------------===//

struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
                                                           adaptor.getSource());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Forwards to `convertCastingOp`, which enforces all preconditions.
/// `assumeAligned` is propagated from the populate entry point and controls
/// acceptance of dynamic offsets.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  ConvertMemRefReinterpretCast(const TypeConverter &typeConverter,
                               MLIRContext *context, bool assumeAligned)
      : OpConversionPattern<memref::ReinterpretCastOp>(typeConverter, context),
        assumeAligned(assumeAligned) {}

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(op.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    return convertCastingOp(rewriter, adaptor, op, newTy, assumeAligned);
  }

private:
  bool assumeAligned;
};

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

/// Emulates a narrow-type memref store with a non-atomic or atomic
/// read-modify-write sequence. `disableAtomicRMW` selects a plain
/// load/modify/store sequence instead of `memref.atomic_rmw` operations to
/// perform the subbyte store.
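///
/// A minimal sketch of the non-atomic sequence for storing an i4 value %v at
/// byte index %i of the converted i8 memref (names are illustrative):
///
///   %ext = arith.extui %v : i4 to i8   // zero-extend to the byte width
///   %val = arith.shli %ext, %off       // align with the destination bits
///   %old = memref.load %m[%i]
///   %clr = arith.andi %old, %mask      // clear the destination bits
///   %new = arith.ori %clr, %val        // merge in the source bits
///   memref.store %new, %m[%i]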
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  ConvertMemrefStore(const TypeConverter &typeConverter, MLIRContext *context,
                     bool disableAtomicRMW)
      : OpConversionPattern<memref::StoreOp>(typeConverter, context),
        disableAtomicRMW(disableAtomicRMW) {}

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();

    // Pad the input value with 0s on the left.
    Value input = adaptor.getValue();
    if (!input.getType().isInteger()) {
      input = arith::BitcastOp::create(
          rewriter, loc,
          IntegerType::get(rewriter.getContext(),
                           input.getType().getIntOrFloatBitWidth()),
          input);
    }
    Value extendedInput =
        arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);

    // Special case 0-rank memref stores. No need for masking. The non-atomic
    // store is used because it operates on the entire value.
    if (convertedType.getRank() == 0) {
      memref::StoreOp::create(rewriter, loc, extendedInput, adaptor.getMemref(),
                              ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);

    if (disableAtomicRMW) {
      // Load the original value.
      Value origValue = memref::LoadOp::create(
          rewriter, loc, adaptor.getMemref(), storeIndices);
      // Clear destination bits (AND with mask).
      Value clearedValue =
          arith::AndIOp::create(rewriter, loc, origValue, writeMask);
      // Write src bits to destination (OR with aligned value), and store the
      // result.
      Value newValue =
          arith::OrIOp::create(rewriter, loc, clearedValue, alignedVal);
      memref::StoreOp::create(rewriter, loc, newValue, adaptor.getMemref(),
                              storeIndices);
    } else {
      // Atomic read-modify-write operations.
      // Clear destination bits.
      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
                                  writeMask, adaptor.getMemref(), storeIndices);
      // Write src bits to destination.
      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
                                  alignedVal, adaptor.getMemref(),
                                  storeIndices);
    }
    rewriter.eraseOp(op);
    return success();
  }

private:
  bool disableAtomicRMW;
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subview has limited support: only static sizes
/// and strides of 1 are handled. When `assumeAligned` is true, dynamic
/// offsets are accepted under the alignment contract that the caller
/// guarantees the offset is a multiple of `dstBits / srcBits`. Without that
/// opt-in, dynamic offsets are rejected. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern should
/// only run for cases that can't be folded.
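///
/// Illustrative example (numbers chosen here): a subview of memref<8x2xi4>
/// with sizes [2, 2] at offset [2, 0] starts at linearized i4 element 4,
/// i.e. byte 2 of the converted memref, and spans 4 i4 values = 2 bytes.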
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  ConvertMemRefSubview(const TypeConverter &typeConverter, MLIRContext *context,
                       bool assumeAligned)
      : OpConversionPattern<memref::SubViewOp>(typeConverter, context),
        assumeAligned(assumeAligned) {}

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only support stride of 1.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    auto sizes = subViewOp.getStaticSizes();
    // TODO: support dynamic sizes. Requires a divisibility analysis or a
    // stronger alignment contract; tracked as follow-up work.
    if (llvm::is_contained(sizes, ShapedType::kDynamic)) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "dynamic size is not supported");
    }

    // Reject dynamic offsets unless the caller has opted into the alignment
    // contract via `assumeAligned`.
    if (!assumeAligned && llvm::is_contained(subViewOp.getStaticOffsets(),
                                             ShapedType::kDynamic)) {
      return rewriter.notifyMatchFailure(
          subViewOp,
          "dynamic offsets require assumeAligned=true to ensure the offset "
          "is a multiple of dstBits / srcBits");
    }

    // Transform the offsets, sizes and strides according to the emulation.
    auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter));

    if (auto cst = getConstantIntValue(linearizedInfo.intraDataOffset);
        cst && *cst != 0) {
      return rewriter.notifyMatchFailure(
          subViewOp,
          "subview offset is provably not a multiple of dstBits / srcBits");
    }

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }

private:
  bool assumeAligned;
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCollapseShape
//===----------------------------------------------------------------------===//

/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation and
/// there is no dimension to collapse any further.
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};

/// Emulating a `memref.expand_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation and
/// the expansion would just have been undone.
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(expandShapeOp, srcVal);
    return success();
  }
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void memref::populateMemRefNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns, bool disableAtomicRMW, bool assumeAligned) {

  // Populate `memref.*` conversion patterns.
  patterns
      .add<ConvertMemRefAllocation<memref::AllocOp>,
           ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCast,
           ConvertMemRefCopy, ConvertMemRefDealloc, ConvertMemRefCollapseShape,
           ConvertMemRefExpandShape, ConvertMemRefLoad,
           ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast>(
          typeConverter, patterns.getContext());
  patterns.add<ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext(), assumeAligned);
  patterns.insert<ConvertMemrefStore>(typeConverter, patterns.getContext(),
                                      disableAtomicRMW);

  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

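/// Returns the linearized, wide-element shape for `ty`, or kDynamic if any
/// dimension is dynamic. E.g. (illustrative), memref<3x5xi4> with dstBits = 8
/// holds 15 i4 values, which round up to ceil(15 / 2) = 8 bytes.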
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale) to accommodate all the
  // values.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}

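/// A few conversions the callback below performs, assuming a load/store
/// width of 8 (examples chosen here, not exhaustive):
///   memref<3x4xi4> -> memref<6xi8>
///   memref<i4>     -> memref<i8>
///   memref<4xi4, strided<[1], offset: 2>>
///                  -> memref<2xi8, strided<[1], offset: 1>>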
void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        Type elementType = ty.getElementType();
        if (!elementType.isIntOrFloat())
          return ty;

        unsigned width = elementType.getIntOrFloatBitWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only an innermost stride of 1 is handled.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(ty.getStridesAndOffset(strides, offset)))
          return nullptr;
        if (!strides.empty() && strides.back() != 1)
          return nullptr;

        auto newElemTy = IntegerType::get(
            ty.getContext(), loadStoreWidth,
            elementType.isInteger()
                ? cast<IntegerType>(elementType).getSignedness()
                : IntegerType::SignednessSemantics::Signless);
        if (!newElemTy)
          return nullptr;

        StridedLayoutAttr layoutAttr;
        // If the offset is 0, we do not need a strided layout as the stride is
        // 1, so we only use the strided layout if the offset is not 0.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // Check if the number of bits is a multiple of the loadStoreWidth
            // and if so, divide it by the loadStoreWidth to get the offset.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;

            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}