//===- EmulateNarrowType.cpp - Narrow type emulation -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have a rank and stride of 1, with static
/// offset and size. The offset, counted in elements of the old (narrow) type,
/// must be a multiple of the ratio of the new bitwidth to the old bitwidth.
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
                 memref::ReinterpretCastOp::Adaptor adaptor,
                 memref::ReinterpretCastOp op, MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only support stride of 1.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only support static sizes and offsets.
  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");
  }

  SmallVector<int64_t> size;
  if (sizes.size())
    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
  return success();
}
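
// Illustrative example (not part of the upstream file): for an i4 memref
// emulated with i8 storage, srcBits = 4, dstBits = 8, and elementsPerByte = 2.
// A reinterpret_cast with static size 10 and static offset 4 is rewritten into
// a cast on the i8 memref with size ceilDiv(10, 2) = 5 and offset 4 / 2 = 2;
// an offset of 3 would be rejected because 3 % 2 != 0.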

/// When data is loaded/stored at `targetBits` granularity but used at
/// `sourceBits` granularity (`sourceBits` < `targetBits`), each `targetBits`
/// unit is treated as an array of elements of width `sourceBits`.
/// Returns the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
/// located at bit (x % 2) * 4, because one i8 holds two elements of 4 bits
/// each.
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
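
// Illustrative example (not part of the upstream file): with sourceBits = 4
// and targetBits = 8, scaleFactor is 2 and offsetExpr is (s0 mod 2) * 4. For
// srcIdx = 5 this folds to (5 mod 2) * 4 = 4 (the high nibble of its byte);
// srcIdx = 6 yields bit offset 0 (the low nibble).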

/// When writing a sub-byte value, masked bitwise operations are used so that
/// only the relevant bits are modified. This function returns an AND mask for
/// clearing the destination bits in a sub-byte write. E.g., when writing to
/// the second i4 in an i32, 0xFFFFFF0F is created.
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}
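
// Illustrative example (not part of the upstream file): writing an i4 at bit
// offset 4 of an i8. maskRightAligned is 0x0F; shifting it left by the bit
// offset gives writeMaskInverse = 0xF0 (the bits being overwritten), and
// XOR-ing with -1 (all ones) yields the final mask 0x0F, which clears the
// destination nibble while preserving the other one.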

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}
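
// Illustrative example (not part of the upstream file): with srcBits = 4 and
// dstBits = 8, scaler is 2, so the linearized i4 index 5 maps to byte index
// floordiv(5, 2) = 2 in the emulated i8 memref.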

static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType = dyn_cast<MemRefType>(
        this->getTypeConverter()->convertType(op.getType()));
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);
    SmallVector<OpFoldResult> indices(currentType.getRank(), zero);

    // Get linearized type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};
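
// Illustrative example (not part of the upstream file): a `memref.alloc()` of
// `memref<3x5xi4>` with an 8-bit load/store width is linearized to 15 i4
// elements, which occupy ceilDiv(15, 2) = 8 bytes, so the op is rewritten to
// allocate the converted type `memref<8xi8>`.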

//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCopy
//===----------------------------------------------------------------------===//

struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
    if (maybeRankedSource && maybeRankedDest &&
        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
                            "and {1}) is currently unimplemented",
                            maybeRankedSource.getLayout(),
                            maybeRankedDest.getLayout()));
    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
                                                adaptor.getTarget());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefDealloc
//===----------------------------------------------------------------------===//

struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet. This will be accounted for later.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Get the offset and shift the bits to the rightmost.
      // Note: currently only big-endian is supported.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }

    // Get the corresponding bits. If the arith computation bitwidth equals
    // the emulated bitwidth, we apply a mask to extract the low bits. It is
    // not clear if this case actually happens in practice, but we keep the
    // operations just in case. Otherwise, if the arith computation bitwidth
    // is different from the emulated bitwidth, we truncate the result.
    Operation *result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
          loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

    rewriter.replaceOp(op, result->getResult(0));
    return success();
  }
};
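
// Illustrative example (not part of the upstream file): an i4 load at
// linearized index 5 from a memref emulated in i8 becomes an i8 load at byte
// index floordiv(5, 2) = 2, an arithmetic shift right by (5 mod 2) * 4 = 4,
// and finally a truncation (or masking) down to the 4 relevant bits.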

//===----------------------------------------------------------------------===//
// ConvertMemRefMemorySpaceCast
//===----------------------------------------------------------------------===//

struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
                                                           adaptor.getSource());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one dimensional, so only the 0 or 1
/// dimensional cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support 0 or 1 dimensional cases.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
                                                          adaptor.getValue());

    // Special case 0-rank memref stores. No need for masking.
    if (convertedType.getRank() == 0) {
      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
                                           extendedInput, adaptor.getMemref(),
                                           ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);

    // Clear the destination bits.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                         writeMask, adaptor.getMemref(),
                                         storeIndices);
    // Write the source bits to the destination.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                         alignedVal, adaptor.getMemref(),
                                         storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};
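
// Illustrative example (not part of the upstream file): storing an i4 value at
// linearized index 5 of a memref emulated in i8. The value is zero-extended to
// i8 and shifted left by (5 mod 2) * 4 = 4; the byte at index
// floordiv(5, 2) = 2 is then updated with two read-modify-write ops: an atomic
// `andi` with mask 0x0F clears the destination nibble and an atomic `ori`
// writes the new bits.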

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subview has limited support: only static offsets,
/// static sizes, and strides of 1 are supported. Ideally, the subview should
/// be folded away before running narrow type emulation, and this pattern
/// should only run for cases that can't be folded.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy = dyn_cast<MemRefType>(
        getTypeConverter()->convertType(subViewOp.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only support stride of 1.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    // Only support static sizes and offsets.
    if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");
    }

    // Transform the offsets, sizes and strides according to the emulation.
    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
        loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter));

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCollapseShape
//===----------------------------------------------------------------------===//

/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation and
/// there is no dimension to collapse any further.
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};

/// Emulating a `memref.expand_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation and
/// the expansion would just have been undone.
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(expandShapeOp, srcVal);
    return success();
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void memref::populateMemRefNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
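
// Illustrative usage sketch (not part of the upstream file): a pass that wants
// narrow-type emulation for memrefs would typically build a
// NarrowTypeEmulationConverter for the desired load/store width, register the
// conversions and patterns from this file (together with the corresponding
// arith/vector populate functions), and run a dialect conversion. Assuming a
// suitably configured ConversionTarget `target`:
//
//   arith::NarrowTypeEmulationConverter converter(/*targetBitwidth=*/8);
//   memref::populateMemRefNarrowTypeEmulationConversions(converter);
//   RewritePatternSet patterns(ctx);
//   memref::populateMemRefNarrowTypeEmulationPatterns(converter, patterns);
//   (void)applyPartialConversion(op, target, std::move(patterns));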

static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale) to accommodate all the
  // values.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}
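
// Illustrative example (not part of the upstream file): memref<3x5xi4> with an
// 8-bit load/store width has 15 i4 elements and scale = 2, so the linearized
// shape is ceilDiv(15, 2) = 8, i.e. the emulated buffer is memref<8xi8>. A
// dynamic dimension makes the whole linearized shape dynamic.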

void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        unsigned width = intTy.getWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only handle memrefs whose innermost stride is 1.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(ty, strides, offset)))
          return std::nullopt;
        if (!strides.empty() && strides.back() != 1)
          return std::nullopt;

        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                          intTy.getSignedness());
        if (!newElemTy)
          return std::nullopt;

        StridedLayoutAttr layoutAttr;
        // If the offset is 0, we do not need a strided layout as the stride is
        // 1, so we only use the strided layout if the offset is not 0.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // Check if the number of bits is a multiple of the loadStoreWidth
            // and, if so, divide by the loadStoreWidth to get the new offset.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;

            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}
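
// Illustrative example (not part of the upstream file): with an 8-bit
// load/store width, memref<3x5xi4, strided<[5, 1], offset: 8>> converts to
// memref<8xi8, strided<[1], offset: 4>>: the 15 i4 elements are packed into
// ceilDiv(15, 2) = 8 bytes and the element offset 8 becomes (8 * 4) / 8 = 4
// bytes. An element offset of 3 would be rejected because 3 * 4 bits is not a
// multiple of 8.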