//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have a rank and stride of 1, with static
/// offset and size. The number of bits in the offset must evenly divide the
/// bitwidth of the new converted type.
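///
/// For example, when emulating i4 with i8, `elementsPerByte` is 2, so a
/// static offset of 4 and a size of 6 (both in i4 elements) become an offset
/// of 2 and a size of ceilDiv(6, 2) = 3 in i8 elements.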
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
                 memref::ReinterpretCastOp::Adaptor adaptor,
                 memref::ReinterpretCastOp op, MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only support stride of 1.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only support static sizes and offsets.
  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");
  }

  SmallVector<int64_t> size;
  if (sizes.size())
    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
  return success();
}

/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
/// Return the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
/// located at (x % 2) * 4, because there are two elements in one i8 and each
/// element has 4 bits.
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return arith::IndexCastOp::create(builder, loc, dstType, bitOffset);
}

/// When writing a subbyte size, masked bitwise operations are used to only
/// modify the relevant bits. This function returns an and mask for clearing
/// the destination bits in a subbyte write. E.g., when writing to the second
/// i4 in an i32, 0xFFFFFF0F is created.
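///
/// Concretely, following the constants built below: the right-aligned mask
/// for srcBits = 4 is 0xF; shifting it left by a bit offset of 4 gives 0xF0;
/// XOR-ing with an all-ones value then yields the clearing mask 0xFFFFFF0F
/// for a 32-bit destination.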
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = arith::ConstantOp::create(
      builder, loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr);
  return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal);
}

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
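///
/// For example, with srcBits = 4 and dstBits = 8 the scale factor is 2, so a
/// linearized i4 index of 5 maps to wide-element index 5 floordiv 2 = 2.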
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}

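/// Linearizes the multi-dimensional `indices` of the original access into a
/// single index expressed in `srcBits` granularity, using the strided
/// metadata of `memref`. Scaling down to the wider element type is done
/// separately by getIndicesForLoadOrStore.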
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      memref::ExtractStridedMetadataOp::create(builder, loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

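/// Converts memref.alloc / memref.alloca over narrow element types into the
/// flattened, wider-element allocation produced by the type converter.
/// Zero-rank allocations are forwarded directly; otherwise the mixed sizes
/// are linearized and, when the converted type has a dynamic shape, the
/// linearized size becomes the single dynamic operand of the new allocation.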
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType =
        this->getTypeConverter()->template convertType<MemRefType>(
            op.getType());
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);

    // Get linearized type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCopy
//===----------------------------------------------------------------------===//

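/// Forwards memref.copy to the converted source and target memrefs. Copies
/// between ranked memrefs with distinct layouts are rejected as currently
/// unimplemented.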
struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
    if (maybeRankedSource && maybeRankedDest &&
        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
                            "and {1}) is currently unimplemented",
                            maybeRankedSource.getLayout(),
                            maybeRankedDest.getLayout()));
    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
                                                adaptor.getTarget());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefDealloc
//===----------------------------------------------------------------------===//

struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

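/// Emulates a narrow-type memref.load by loading the containing wider
/// element, shifting the requested sub-element down to the least-significant
/// bits, and masking or truncating it back to the original bitwidth. A rough
/// sketch, assuming i4 emulated with i8 (%idx and %off stand for the folded
/// index and bit-offset computations):
///
///   %x = memref.load %src[%i] : memref<8xi4>
///     -->
///   %w = memref.load %new[%idx] : memref<4xi8>
///   %s = arith.shrsi %w, %off : i8
///   %x = arith.trunci %s : i8 to i4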
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
                                        ValueRange{});
    } else {
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet. This will be accounted for later.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = memref::LoadOp::create(
          rewriter, loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Get the offset and shift the bits to the rightmost.
      // Note, currently only the big-endian is supported.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset);
    }

    // Get the corresponding bits. If the arith computation bitwidth equals
    // the emulated bitwidth, we apply a mask to extract the low bits. It is
    // not clear if this case actually happens in practice, but we keep the
    // operations just in case. Otherwise, if the arith computation bitwidth
    // is different from the emulated bitwidth, we truncate the result.
    Value result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    auto conversionTy =
        resultTy.isInteger()
            ? resultTy
            : IntegerType::get(rewriter.getContext(),
                               resultTy.getIntOrFloatBitWidth());
    if (conversionTy == convertedElementType) {
      auto mask = arith::ConstantOp::create(
          rewriter, loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask);
    } else {
      result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad);
    }

    if (conversionTy != resultTy) {
      result = arith::BitcastOp::create(rewriter, loc, resultTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefMemorySpaceCast
//===----------------------------------------------------------------------===//

struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(
        op, newTy, adaptor.getSource());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one dimensional, so only the 0 or 1
/// dimensional cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(op.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support the 0 or 1 dimensional cases.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

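/// Emulates a narrow-type memref.store as a read-modify-write on the wider
/// element: the value is zero-extended and shifted into position, the
/// destination bits are cleared with an atomic `andi` on the write mask, and
/// the shifted value is merged in with an atomic `ori`. For instance, when
/// writing the second i4 packed in an i8, the write mask is 0x0F (clearing
/// the high nibble) and the value is shifted left by 4 before the `ori`.
/// Zero-rank memrefs skip the masking and use a single atomic `assign`.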
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();

    // Pad the input value with 0s on the left.
    Value input = adaptor.getValue();
    if (!input.getType().isInteger()) {
      input = arith::BitcastOp::create(
          rewriter, loc,
          IntegerType::get(rewriter.getContext(),
                           input.getType().getIntOrFloatBitWidth()),
          input);
    }
    Value extendedInput =
        arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);

    // Special case 0-rank memref stores. No need for masking.
    if (convertedType.getRank() == 0) {
      memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign,
                                  extendedInput, adaptor.getMemref(),
                                  ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);

    // Clear the destination bits.
    memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
                                writeMask, adaptor.getMemref(), storeIndices);
    // Write the source bits to the destination.
    memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
                                alignedVal, adaptor.getMemref(), storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subview has limited support, supporting only
/// static offset and size and stride of 1. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern should
/// only run for cases that can't be folded.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only support stride of 1.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    // Only support static sizes and offsets.
    if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");
    }

    // Transform the offsets, sizes and strides according to the emulation.
    auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter));

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCollapseShape
//===----------------------------------------------------------------------===//

/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation and
/// there is no dimension to collapse any further.
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};

/// Emulating a `memref.expand_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation and
/// the expansion would just have been undone.
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(expandShapeOp, srcVal);
    return success();
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

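// Typical driver-side usage of the populate functions below (a minimal
// sketch; `ctx`, `op`, and the conversion `target` are assumed to be set up
// by the calling pass):
//
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
//   RewritePatternSet patterns(ctx);
//   memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
//   (void)applyPartialConversion(op, target, std::move(patterns));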
void memref::populateMemRefNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

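/// Returns the linearized 1-D shape of `ty` once `dstBits / srcBits` narrow
/// elements are packed into each wide element; any dynamic dimension
/// collapses the result to a single dynamic size.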
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale) to accommodate all the
  // values.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}

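/// Registers the memref type conversion: a memref whose integer or float
/// element type is narrower than the converter's load/store bitwidth is
/// rewritten to a linearized 1-D memref of load/store-width integers. For
/// example, with a load/store bitwidth of 8, `memref<3x5xi4>` becomes
/// `memref<8xi8>`: 3 * 5 = 15 i4 values are packed two per byte and rounded
/// up to 8 bytes.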
void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        Type elementType = ty.getElementType();
        if (!elementType.isIntOrFloat())
          return ty;

        unsigned width = elementType.getIntOrFloatBitWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only handle the case where the innermost stride is 1.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(ty.getStridesAndOffset(strides, offset)))
          return nullptr;
        if (!strides.empty() && strides.back() != 1)
          return nullptr;

        auto newElemTy = IntegerType::get(
            ty.getContext(), loadStoreWidth,
            elementType.isInteger()
                ? cast<IntegerType>(elementType).getSignedness()
                : IntegerType::SignednessSemantics::Signless);
        if (!newElemTy)
          return nullptr;

        StridedLayoutAttr layoutAttr;
        // If the offset is 0, we do not need a strided layout as the stride is
        // 1, so we only use the strided layout if the offset is not 0.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // Check that the offset in bits is a multiple of the
            // loadStoreWidth and, if so, divide by the loadStoreWidth to get
            // the new offset.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;

            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}