EmulateNarrowType.cpp
1//===- EmulateNarrowType.cpp - Narrow type emulation ------------*- C++ -*-===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
19#include "mlir/IR/Builders.h"
23#include "llvm/Support/FormatVariadic.h"
24#include "llvm/Support/MathExtras.h"
25#include <cassert>
26#include <type_traits>
27
28using namespace mlir;
29
30//===----------------------------------------------------------------------===//
31// Utility functions
32//===----------------------------------------------------------------------===//
33
34/// Converts a memref::ReinterpretCastOp to the converted type. The result
35/// MemRefType of the old op must have rank 1 and a stride of 1, with a static
36/// offset and size. The offset, expressed in bits, must be evenly divisible
37/// by the bitwidth of the new converted element type.
38static LogicalResult
39convertCastingOp(ConversionPatternRewriter &rewriter,
40 memref::ReinterpretCastOp::Adaptor adaptor,
41 memref::ReinterpretCastOp op, MemRefType newTy) {
42 auto convertedElementType = newTy.getElementType();
43 auto oldElementType = op.getType().getElementType();
44 int srcBits = oldElementType.getIntOrFloatBitWidth();
45 int dstBits = convertedElementType.getIntOrFloatBitWidth();
46 if (dstBits % srcBits != 0) {
47 return rewriter.notifyMatchFailure(op,
48 "only dstBits % srcBits == 0 supported");
49 }
50
51 // Only support stride of 1.
52 if (llvm::any_of(op.getStaticStrides(),
53 [](int64_t stride) { return stride != 1; })) {
54 return rewriter.notifyMatchFailure(op->getLoc(),
55 "stride != 1 is not supported");
56 }
57
58 auto sizes = op.getStaticSizes();
59 int64_t offset = op.getStaticOffset(0);
60 // Only support static sizes and offsets.
61 if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
62 offset == ShapedType::kDynamic) {
63 return rewriter.notifyMatchFailure(
64 op, "dynamic size or offset is not supported");
65 }
66
67 int elementsPerByte = dstBits / srcBits;
68 if (offset % elementsPerByte != 0) {
69 return rewriter.notifyMatchFailure(
70 op, "offset not multiple of elementsPerByte is not supported");
71 }
72
73 SmallVector<int64_t> size;
74 if (!sizes.empty())
75 size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
76 offset = offset / elementsPerByte;
77
78 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
79 op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
80 return success();
81}
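// Illustrative sketch (assumed example, not taken from the upstream sources or
// tests): for i4 emulated on i8 (srcBits = 4, dstBits = 8, elementsPerByte = 2)
// this casting conversion roughly turns
//
//   memref.reinterpret_cast %src to offset: [4], sizes: [10], strides: [1]
//       : memref<?xi4> to memref<10xi4, strided<[1], offset: 4>>
//
// into
//
//   memref.reinterpret_cast %src to offset: [2], sizes: [5], strides: [1]
//       : memref<?xi8> to memref<5xi8, strided<[1], offset: 2>>
//
// i.e. the offset is divided by elementsPerByte and the size is rounded up.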
82
83/// When data is loaded/stored in `targetBits` granularity, but is used in
84/// `sourceBits` granularity (`sourceBits` < `targetBits`), each `targetBits`
85/// word is treated as an array of elements of width `sourceBits`.
86/// Returns the bit offset of the value at position `srcIdx`. For example, if
87/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
88/// located at (x % 2) * 4, because there are two elements in one i8 and one
89/// element has 4 bits.
90static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
91 int sourceBits, int targetBits,
92 OpBuilder &builder) {
93 assert(targetBits % sourceBits == 0);
94 AffineExpr s0;
95 bindSymbols(builder.getContext(), s0);
96 int scaleFactor = targetBits / sourceBits;
97 AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
98 OpFoldResult offsetVal =
99 affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
100 Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
101 IntegerType dstType = builder.getIntegerType(targetBits);
102 return arith::IndexCastOp::create(builder, loc, dstType, bitOffset);
103}
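// Illustrative worked example (assumed, not from the upstream file): with
// sourceBits = 4 and targetBits = 8, srcIdx = 5 yields (5 mod 2) * 4 = 4 (the
// upper nibble) and srcIdx = 6 yields (6 mod 2) * 4 = 0 (the lower nibble).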
104
105/// When writing a subbyte size, masked bitwise operations are used to only
106/// modify the relevant bits. This function returns an AND mask for clearing
107/// the destination bits in a subbyte write. E.g., when writing to the second
108/// i4 in an i32, 0xFFFFFF0F is created.
109static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
110 int64_t srcBits, int64_t dstBits,
111 Value bitwidthOffset, OpBuilder &builder) {
112 auto dstIntegerType = builder.getIntegerType(dstBits);
113 auto maskRightAlignedAttr =
114 builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
115 Value maskRightAligned = arith::ConstantOp::create(
116 builder, loc, dstIntegerType, maskRightAlignedAttr);
117 Value writeMaskInverse =
118 arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset);
119 auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
120 Value flipVal =
121 arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr);
122 return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal);
123}
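// Illustrative worked example (assumed, not from the upstream file): for
// srcBits = 4, dstBits = 8 and a bit offset of 4, maskRightAligned is 0x0F,
// the shifted mask is 0xF0, and XOR-ing with -1 yields the final write mask
// 0x0F, which clears the upper nibble of the destination byte when AND-ed
// with it.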
124
125/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
126/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
127/// the returned index has the granularity of `dstBits`.
128static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
129 OpFoldResult linearizedIndex,
130 int64_t srcBits, int64_t dstBits) {
131 AffineExpr s0;
132 bindSymbols(builder.getContext(), s0);
133 int64_t scaler = dstBits / srcBits;
134 OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
135 builder, loc, s0.floorDiv(scaler), {linearizedIndex});
136 return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
137}
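// Illustrative worked example (assumed, not from the upstream file): with
// srcBits = 4 and dstBits = 8, a linearized i4 index of 5 maps to byte index
// 5 floordiv 2 = 2.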
138
139static OpFoldResult
140getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
141 const SmallVector<OpFoldResult> &indices,
142 Value memref) {
143 auto stridedMetadata =
144 memref::ExtractStridedMetadataOp::create(builder, loc, memref);
145 OpFoldResult linearizedIndices;
146 std::tie(std::ignore, linearizedIndices) =
147 memref::getLinearizedMemRefOffsetAndSize(
148 builder, loc, srcBits, srcBits,
149 stridedMetadata.getConstifiedMixedOffset(),
150 stridedMetadata.getConstifiedMixedSizes(),
151 stridedMetadata.getConstifiedMixedStrides(), indices);
152 return linearizedIndices;
153}
154
155namespace {
156
157//===----------------------------------------------------------------------===//
158// ConvertMemRefAllocation
159//===----------------------------------------------------------------------===//
160
161template <typename OpTy>
162struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
163 using OpConversionPattern<OpTy>::OpConversionPattern;
164
165 LogicalResult
166 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
167 ConversionPatternRewriter &rewriter) const override {
168 static_assert(std::is_same<OpTy, memref::AllocOp>() ||
169 std::is_same<OpTy, memref::AllocaOp>(),
170 "expected only memref::AllocOp or memref::AllocaOp");
171 auto currentType = cast<MemRefType>(op.getMemref().getType());
172 auto newResultType =
173 this->getTypeConverter()->template convertType<MemRefType>(
174 op.getType());
175 if (!newResultType) {
176 return rewriter.notifyMatchFailure(
177 op->getLoc(),
178 llvm::formatv("failed to convert memref type: {0}", op.getType()));
179 }
180
181 // Special case zero-rank memrefs.
182 if (currentType.getRank() == 0) {
183 rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
184 adaptor.getSymbolOperands(),
185 adaptor.getAlignmentAttr());
186 return success();
187 }
188
189 Location loc = op.getLoc();
190 OpFoldResult zero = rewriter.getIndexAttr(0);
191
192 // Get linearized type.
193 int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
194 int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
195 SmallVector<OpFoldResult> sizes = op.getMixedSizes();
196
197 memref::LinearizedMemRefInfo linearizedMemRefInfo =
198 memref::getLinearizedMemRefOffsetAndSize(
199 rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
200 SmallVector<Value> dynamicLinearizedSize;
201 if (!newResultType.hasStaticShape()) {
202 dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
203 rewriter, loc, linearizedMemRefInfo.linearizedSize));
204 }
205
206 rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
207 adaptor.getSymbolOperands(),
208 adaptor.getAlignmentAttr());
209 return success();
210 }
211};
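// Illustrative sketch (assumed example, not taken from the upstream tests):
// for i4 emulated on i8, this allocation rewrite roughly turns
//
//   %m = memref.alloc() : memref<3x4xi4>
//
// into
//
//   %m = memref.alloc() : memref<6xi8>
//
// since the 12 i4 values are linearized and packed two per byte. For dynamic
// shapes, the linearized size is materialized as an index operand instead.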
212
213//===----------------------------------------------------------------------===//
214// ConvertMemRefAssumeAlignment
215//===----------------------------------------------------------------------===//
216
217struct ConvertMemRefAssumeAlignment final
218 : OpConversionPattern<memref::AssumeAlignmentOp> {
219 using OpConversionPattern::OpConversionPattern;
220
221 LogicalResult
222 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
223 ConversionPatternRewriter &rewriter) const override {
224 Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
225 if (!newTy) {
226 return rewriter.notifyMatchFailure(
227 op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
228 op.getMemref().getType()));
229 }
230
231 rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
232 op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
233 return success();
234 }
235};
236
237//===----------------------------------------------------------------------===//
238// ConvertMemRefCopy
239//===----------------------------------------------------------------------===//
240
241struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
242 using OpConversionPattern::OpConversionPattern;
243
244 LogicalResult
245 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
246 ConversionPatternRewriter &rewriter) const override {
247 auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
248 auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
249 if (maybeRankedSource && maybeRankedDest &&
250 maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
251 return rewriter.notifyMatchFailure(
252 op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
253 "and {1}) is currently unimplemented",
254 maybeRankedSource.getLayout(),
255 maybeRankedDest.getLayout()));
256 rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
257 adaptor.getTarget());
258 return success();
259 }
260};
261
262//===----------------------------------------------------------------------===//
263// ConvertMemRefDealloc
264//===----------------------------------------------------------------------===//
265
266struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
267 using OpConversionPattern::OpConversionPattern;
268
269 LogicalResult
270 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
271 ConversionPatternRewriter &rewriter) const override {
272 rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
273 return success();
274 }
275};
276
277//===----------------------------------------------------------------------===//
278// ConvertMemRefLoad
279//===----------------------------------------------------------------------===//
280
281struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
282 using OpConversionPattern::OpConversionPattern;
283
284 LogicalResult
285 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
286 ConversionPatternRewriter &rewriter) const override {
287 auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
288 auto convertedElementType = convertedType.getElementType();
289 auto oldElementType = op.getMemRefType().getElementType();
290 int srcBits = oldElementType.getIntOrFloatBitWidth();
291 int dstBits = convertedElementType.getIntOrFloatBitWidth();
292 if (dstBits % srcBits != 0) {
293 return rewriter.notifyMatchFailure(
294 op, "only dstBits % srcBits == 0 supported");
295 }
296
297 Location loc = op.getLoc();
298 // Special case 0-rank memref loads.
299 Value bitsLoad;
300 if (convertedType.getRank() == 0) {
301 bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
302 ValueRange{});
303 } else {
304 // Linearize the indices of the original load instruction. Do not account
305 // for the scaling yet. This will be accounted for later.
306 OpFoldResult linearizedIndices = getLinearizedSrcIndices(
307 rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
308
309 Value newLoad = memref::LoadOp::create(
310 rewriter, loc, adaptor.getMemref(),
311 getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
312 dstBits));
313
314 // Get the offset and shift the bits to the rightmost.
315 // Note: currently only big-endian is supported.
316 Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
317 srcBits, dstBits, rewriter);
318 bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset);
319 }
320
321 // Get the corresponding bits. If the arith computation bitwidth equals
322 // the emulated bitwidth, we apply a mask to extract the low bits.
323 // It is not clear if this case actually happens in practice, but we keep
324 // the operations just in case. Otherwise, if the arith computation
325 // bitwidth differs from the emulated bitwidth, we truncate the result.
326 Value result;
327 auto resultTy = getTypeConverter()->convertType(oldElementType);
328 auto conversionTy =
329 resultTy.isInteger()
330 ? resultTy
331 : IntegerType::get(rewriter.getContext(),
332 resultTy.getIntOrFloatBitWidth());
333 if (conversionTy == convertedElementType) {
334 auto mask = arith::ConstantOp::create(
335 rewriter, loc, convertedElementType,
336 rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
337
338 result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask);
339 } else {
340 result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad);
341 }
342
343 if (conversionTy != resultTy) {
344 result = arith::BitcastOp::create(rewriter, loc, resultTy, result);
345 }
346
347 rewriter.replaceOp(op, result);
348 return success();
349 }
350};
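// Illustrative sketch (assumed pseudo-IR, value names are illustrative only):
// for i4 emulated on i8, the load rewrite above roughly turns
//
//   %v = memref.load %m[%i] : memref<16xi4>
//
// into
//
//   %byte    = memref.load %m_i8[%i floordiv 2] : memref<8xi8>
//   %shifted = arith.shrsi %byte, ((%i mod 2) * 4)
//   %v       = arith.trunci %shifted : i8 to i4
//
// (or an arith.andi with 0xF instead of the truncation when the arith
// computation bitwidth already equals the emulated bitwidth, as handled
// above).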
351
352//===----------------------------------------------------------------------===//
353// ConvertMemRefMemorySpaceCast
354//===----------------------------------------------------------------------===//
355
356struct ConvertMemRefMemorySpaceCast final
357 : OpConversionPattern<memref::MemorySpaceCastOp> {
358 using OpConversionPattern::OpConversionPattern;
359
360 LogicalResult
361 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
362 ConversionPatternRewriter &rewriter) const override {
363 Type newTy = getTypeConverter()->convertType(op.getDest().getType());
364 if (!newTy) {
365 return rewriter.notifyMatchFailure(
366 op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
367 op.getDest().getType()));
368 }
369
370 rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
371 adaptor.getSource());
372 return success();
373 }
374};
375
376//===----------------------------------------------------------------------===//
377// ConvertMemRefReinterpretCast
378//===----------------------------------------------------------------------===//
379
380/// Output types should be at most one-dimensional, so only the 0-D and 1-D
381/// cases are supported.
382struct ConvertMemRefReinterpretCast final
383 : OpConversionPattern<memref::ReinterpretCastOp> {
384 using OpConversionPattern::OpConversionPattern;
385
386 LogicalResult
387 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
388 ConversionPatternRewriter &rewriter) const override {
389 MemRefType newTy =
390 getTypeConverter()->convertType<MemRefType>(op.getType());
391 if (!newTy) {
392 return rewriter.notifyMatchFailure(
393 op->getLoc(),
394 llvm::formatv("failed to convert memref type: {0}", op.getType()));
395 }
396
397 // Only support for 0 or 1 dimensional cases.
398 if (op.getType().getRank() > 1) {
399 return rewriter.notifyMatchFailure(
400 op->getLoc(), "subview with rank > 1 is not supported");
401 }
402
403 return convertCastingOp(rewriter, adaptor, op, newTy);
404 }
405};
406
407//===----------------------------------------------------------------------===//
408// ConvertMemrefStore
409//===----------------------------------------------------------------------===//
410
411struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
412 using OpConversionPattern::OpConversionPattern;
413
414 LogicalResult
415 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
416 ConversionPatternRewriter &rewriter) const override {
417 auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
418 int srcBits = op.getMemRefType().getElementTypeBitWidth();
419 int dstBits = convertedType.getElementTypeBitWidth();
420 auto dstIntegerType = rewriter.getIntegerType(dstBits);
421 if (dstBits % srcBits != 0) {
422 return rewriter.notifyMatchFailure(
423 op, "only dstBits % srcBits == 0 supported");
424 }
425
426 Location loc = op.getLoc();
427
428 // Pad the input value with 0s on the left.
429 Value input = adaptor.getValue();
430 if (!input.getType().isInteger()) {
431 input = arith::BitcastOp::create(
432 rewriter, loc,
433 IntegerType::get(rewriter.getContext(),
434 input.getType().getIntOrFloatBitWidth()),
435 input);
436 }
437 Value extendedInput =
438 arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);
439
440 // Special case 0-rank memref stores. No need for masking.
441 if (convertedType.getRank() == 0) {
442 memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign,
443 extendedInput, adaptor.getMemref(),
444 ValueRange{});
445 rewriter.eraseOp(op);
446 return success();
447 }
448
449 OpFoldResult linearizedIndices = getLinearizedSrcIndices(
450 rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
451 Value storeIndices = getIndicesForLoadOrStore(
452 rewriter, loc, linearizedIndices, srcBits, dstBits);
453 Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
454 dstBits, rewriter);
455 Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
456 dstBits, bitwidthOffset, rewriter);
457 // Align the value to write with the destination bits
458 Value alignedVal =
459 arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);
460
461 // Clear destination bits
462 memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
463 writeMask, adaptor.getMemref(), storeIndices);
464 // Write the source bits into the destination
465 memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
466 alignedVal, adaptor.getMemref(), storeIndices);
467 rewriter.eraseOp(op);
468 return success();
469 }
470};
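// Illustrative sketch (assumed pseudo-IR, value names are illustrative only):
// for i4 emulated on i8, the store rewrite above roughly turns
//
//   memref.store %v, %m[%i] : memref<16xi4>
//
// into
//
//   %ext     = arith.extui %v : i4 to i8
//   %aligned = arith.shli %ext, ((%i mod 2) * 4)
//   memref.atomic_rmw andi %mask,    %m_i8[%i floordiv 2]  // clear dst bits
//   memref.atomic_rmw ori  %aligned, %m_i8[%i floordiv 2]  // set new bits
//
// where %mask is the write mask produced by getSubByteWriteMask.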
471
472//===----------------------------------------------------------------------===//
473// ConvertMemRefSubview
474//===----------------------------------------------------------------------===//
475
476/// Emulating narrow ints on subview has limited support: only static offsets
477/// and sizes with a stride of 1 are supported. Ideally, the subview should be
478/// folded away before running narrow type emulation, and this pattern should
479/// only run for cases that can't be folded.
480struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
481 using OpConversionPattern::OpConversionPattern;
482
483 LogicalResult
484 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
485 ConversionPatternRewriter &rewriter) const override {
486 MemRefType newTy =
487 getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
488 if (!newTy) {
489 return rewriter.notifyMatchFailure(
490 subViewOp->getLoc(),
491 llvm::formatv("failed to convert memref type: {0}",
492 subViewOp.getType()));
493 }
494
495 Location loc = subViewOp.getLoc();
496 Type convertedElementType = newTy.getElementType();
497 Type oldElementType = subViewOp.getType().getElementType();
498 int srcBits = oldElementType.getIntOrFloatBitWidth();
499 int dstBits = convertedElementType.getIntOrFloatBitWidth();
500 if (dstBits % srcBits != 0)
501 return rewriter.notifyMatchFailure(
502 subViewOp, "only dstBits % srcBits == 0 supported");
503
504 // Only support stride of 1.
505 if (llvm::any_of(subViewOp.getStaticStrides(),
506 [](int64_t stride) { return stride != 1; })) {
507 return rewriter.notifyMatchFailure(subViewOp->getLoc(),
508 "stride != 1 is not supported");
509 }
510
511 if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
512 return rewriter.notifyMatchFailure(
513 subViewOp, "the result memref type is not contiguous");
514 }
515
516 auto sizes = subViewOp.getStaticSizes();
517 int64_t lastOffset = subViewOp.getStaticOffsets().back();
518 // Only support static sizes and offsets.
519 if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
520 lastOffset == ShapedType::kDynamic) {
521 return rewriter.notifyMatchFailure(
522 subViewOp->getLoc(), "dynamic size or offset is not supported");
523 }
524
525 // Transform the offsets, sizes and strides according to the emulation.
526 auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
527 rewriter, loc, subViewOp.getViewSource());
528
529 OpFoldResult linearizedIndices;
530 auto strides = stridedMetadata.getConstifiedMixedStrides();
531 memref::LinearizedMemRefInfo linearizedInfo;
532 std::tie(linearizedInfo, linearizedIndices) =
533 memref::getLinearizedMemRefOffsetAndSize(
534 rewriter, loc, srcBits, dstBits,
535 stridedMetadata.getConstifiedMixedOffset(),
536 subViewOp.getMixedSizes(), strides,
537 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
538 rewriter));
539
540 rewriter.replaceOpWithNewOp<memref::SubViewOp>(
541 subViewOp, newTy, adaptor.getSource(), linearizedIndices,
542 linearizedInfo.linearizedSize, strides.back());
543 return success();
544 }
545};
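// Illustrative sketch (assumed example, not taken from the upstream tests):
// for i4 emulated on i8, the subview rewrite above roughly turns
//
//   memref.subview %m[2] [8] [1]
//       : memref<16xi4> to memref<8xi4, strided<[1], offset: 2>>
//
// into a subview of the flattened i8 memref with the offset and size rescaled
// to byte granularity (offset 1 and size 4 here).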
546
547//===----------------------------------------------------------------------===//
548// ConvertMemRefCollapseShape
549//===----------------------------------------------------------------------===//
550
551/// Emulating `memref.collapse_shape` becomes a no-op given that we flatten
552/// memrefs to a single dimension as part of the emulation, so there is no
553/// dimension left to collapse any further.
554struct ConvertMemRefCollapseShape final
555 : OpConversionPattern<memref::CollapseShapeOp> {
556 using OpConversionPattern::OpConversionPattern;
557
558 LogicalResult
559 matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
560 ConversionPatternRewriter &rewriter) const override {
561 Value srcVal = adaptor.getSrc();
562 auto newTy = dyn_cast<MemRefType>(srcVal.getType());
563 if (!newTy)
564 return failure();
565
566 if (newTy.getRank() != 1)
567 return failure();
568
569 rewriter.replaceOp(collapseShapeOp, srcVal);
570 return success();
571 }
572};
573
574/// Emulating `memref.expand_shape` becomes a no-op given that we flatten
575/// memrefs to a single dimension as part of the emulation, so the expansion
576/// would just be undone again.
577struct ConvertMemRefExpandShape final
578 : OpConversionPattern<memref::ExpandShapeOp> {
579 using OpConversionPattern::OpConversionPattern;
580
581 LogicalResult
582 matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
583 ConversionPatternRewriter &rewriter) const override {
584 Value srcVal = adaptor.getSrc();
585 auto newTy = dyn_cast<MemRefType>(srcVal.getType());
586 if (!newTy)
587 return failure();
588
589 if (newTy.getRank() != 1)
590 return failure();
591
592 rewriter.replaceOp(expandShapeOp, srcVal);
593 return success();
594 }
595};
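// Illustrative sketch (assumed example, not taken from the upstream tests):
// since every memref has already been flattened to rank 1 by this emulation,
//
//   memref.collapse_shape %m [[0, 1]] : memref<3x4xi4> into memref<12xi4>
//
// sees both its operand and its result converted to memref<6xi8>, so the op
// is simply replaced by its converted source. The same reasoning applies to
// memref.expand_shape.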
596} // end anonymous namespace
597
598//===----------------------------------------------------------------------===//
599// Public Interface Definition
600//===----------------------------------------------------------------------===//
601
602void memref::populateMemRefNarrowTypeEmulationPatterns(
603 const arith::NarrowTypeEmulationConverter &typeConverter,
604 RewritePatternSet &patterns) {
605
606 // Populate `memref.*` conversion patterns.
607 patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
608 ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
609 ConvertMemRefDealloc, ConvertMemRefCollapseShape,
610 ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
611 ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
612 ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
613 typeConverter, patterns.getContext());
614 memref::populateResolveExtractStridedMetadataPatterns(patterns);
615}
616
617static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
618 int dstBits) {
619 if (ty.getRank() == 0)
620 return {};
621
622 int64_t linearizedShape = 1;
623 for (auto shape : ty.getShape()) {
624 if (shape == ShapedType::kDynamic)
625 return {ShapedType::kDynamic};
626 linearizedShape *= shape;
627 }
628 int scale = dstBits / srcBits;
629 // Scale the size to ceilDiv(linearizedShape, scale)
630 // to accommodate all the values.
631 linearizedShape = (linearizedShape + scale - 1) / scale;
632 return {linearizedShape};
633}
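// Illustrative worked example (assumed, not from the upstream file): for
// memref<3x4xi4> with a load/store width of 8, linearizedShape = 12 and
// scale = 2, so the returned shape is {6}; any dynamic dimension makes the
// whole result dynamic.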
634
635void memref::populateMemRefNarrowTypeEmulationConversions(
636 arith::NarrowTypeEmulationConverter &typeConverter) {
637 typeConverter.addConversion(
638 [&typeConverter](MemRefType ty) -> std::optional<Type> {
639 Type elementType = ty.getElementType();
640 if (!elementType.isIntOrFloat())
641 return ty;
642
643 unsigned width = elementType.getIntOrFloatBitWidth();
644 unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
645 if (width >= loadStoreWidth)
646 return ty;
647
648 // Currently, only an innermost stride of 1 is handled; check for it.
649 SmallVector<int64_t> strides;
650 int64_t offset;
651 if (failed(ty.getStridesAndOffset(strides, offset)))
652 return nullptr;
653 if (!strides.empty() && strides.back() != 1)
654 return nullptr;
655
656 auto newElemTy = IntegerType::get(
657 ty.getContext(), loadStoreWidth,
658 elementType.isInteger()
659 ? cast<IntegerType>(elementType).getSignedness()
660 : IntegerType::SignednessSemantics::Signless);
661 if (!newElemTy)
662 return nullptr;
663
664 StridedLayoutAttr layoutAttr;
665 // If the offset is 0, we do not need a strided layout as the stride is
666 // 1, so we only use the strided layout if the offset is not 0.
667 if (offset != 0) {
668 if (offset == ShapedType::kDynamic) {
669 layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
670 ArrayRef<int64_t>{1});
671 } else {
672 // Check if the offset in bits is a multiple of loadStoreWidth and,
673 // if so, divide it by loadStoreWidth to get the new offset.
674 if ((offset * width) % loadStoreWidth != 0)
675 return std::nullopt;
676 offset = (offset * width) / loadStoreWidth;
677
678 layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
679 ArrayRef<int64_t>{1});
680 }
681 }
682
683 return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
684 newElemTy, layoutAttr, ty.getMemorySpace());
685 });
686}
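// Illustrative worked example (assumed, not from the upstream tests): with a
// load/store width of 8, this type conversion maps, for example,
//
//   memref<3x4xi4, strided<[4, 1], offset: 8>>
//
// to
//
//   memref<6xi8, strided<[1], offset: 4>>
//
// (12 i4 values packed two per byte; the element offset of 8 i4s is 32 bits,
// i.e. 4 bytes).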