MLIR  20.0.0git
EmulateNarrowType.cpp
Go to the documentation of this file.
1 //===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++
2 //-*-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/OpDefinition.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "llvm/Support/MathExtras.h"
25 #include <cassert>
26 #include <type_traits>
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Utility functions
32 //===----------------------------------------------------------------------===//
33 
34 /// Converts a memref::ReinterpretCastOp to the converted type. The result
35 /// MemRefType of the old op must have a rank and stride of 1, with static
36 /// offset and size. The number of bits in the offset must evenly divide the
37 /// bitwidth of the new converted type.
38 static LogicalResult
40  memref::ReinterpretCastOp::Adaptor adaptor,
41  memref::ReinterpretCastOp op, MemRefType newTy) {
42  auto convertedElementType = newTy.getElementType();
43  auto oldElementType = op.getType().getElementType();
44  int srcBits = oldElementType.getIntOrFloatBitWidth();
45  int dstBits = convertedElementType.getIntOrFloatBitWidth();
46  if (dstBits % srcBits != 0) {
47  return rewriter.notifyMatchFailure(op,
48  "only dstBits % srcBits == 0 supported");
49  }
50 
51  // Only support stride of 1.
52  if (llvm::any_of(op.getStaticStrides(),
53  [](int64_t stride) { return stride != 1; })) {
54  return rewriter.notifyMatchFailure(op->getLoc(),
55  "stride != 1 is not supported");
56  }
57 
58  auto sizes = op.getStaticSizes();
59  int64_t offset = op.getStaticOffset(0);
60  // Only support static sizes and offsets.
61  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
62  offset == ShapedType::kDynamic) {
63  return rewriter.notifyMatchFailure(
64  op, "dynamic size or offset is not supported");
65  }
66 
67  int elementsPerByte = dstBits / srcBits;
68  if (offset % elementsPerByte != 0) {
69  return rewriter.notifyMatchFailure(
70  op, "offset not multiple of elementsPerByte is not supported");
71  }
72 
74  if (sizes.size())
75  size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
76  offset = offset / elementsPerByte;
77 
78  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
79  op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
80  return success();
81 }
82 
83 /// When data is loaded/stored in `targetBits` granularity, but is used in
84 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
85 /// treated as an array of elements of width `sourceBits`.
86 /// Return the bit offset of the value at position `srcIdx`. For example, if
87 /// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
88 /// located at (x % 2) * 4. Because there are two elements in one i8, and one
89 /// element has 4 bits.
91  int sourceBits, int targetBits,
92  OpBuilder &builder) {
93  assert(targetBits % sourceBits == 0);
94  AffineExpr s0;
95  bindSymbols(builder.getContext(), s0);
96  int scaleFactor = targetBits / sourceBits;
97  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
98  OpFoldResult offsetVal =
99  affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
100  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
101  IntegerType dstType = builder.getIntegerType(targetBits);
102  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
103 }
104 
105 /// When writing a subbyte size, masked bitwise operations are used to only
106 /// modify the relevant bits. This function returns an and mask for clearing
107 /// the destination bits in a subbyte write. E.g., when writing to the second
108 /// i4 in an i32, 0xFFFFFF0F is created.
109 static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
110  int64_t srcBits, int64_t dstBits,
111  Value bitwidthOffset, OpBuilder &builder) {
112  auto dstIntegerType = builder.getIntegerType(dstBits);
113  auto maskRightAlignedAttr =
114  builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
115  Value maskRightAligned = builder.create<arith::ConstantOp>(
116  loc, dstIntegerType, maskRightAlignedAttr);
117  Value writeMaskInverse =
118  builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
119  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
120  Value flipVal =
121  builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
122  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
123 }
124 
125 /// Returns the scaled linearized index based on the `srcBits` and `dstBits`
126 /// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
127 /// the returned index has the granularity of `dstBits`
129  OpFoldResult linearizedIndex,
130  int64_t srcBits, int64_t dstBits) {
131  AffineExpr s0;
132  bindSymbols(builder.getContext(), s0);
133  int64_t scaler = dstBits / srcBits;
134  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
135  builder, loc, s0.floorDiv(scaler), {linearizedIndex});
136  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
137 }
138 
139 static OpFoldResult
140 getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
141  const SmallVector<OpFoldResult> &indices,
142  Value memref) {
143  auto stridedMetadata =
144  builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
145  OpFoldResult linearizedIndices;
146  std::tie(std::ignore, linearizedIndices) =
148  builder, loc, srcBits, srcBits,
149  stridedMetadata.getConstifiedMixedOffset(),
150  stridedMetadata.getConstifiedMixedSizes(),
151  stridedMetadata.getConstifiedMixedStrides(), indices);
152  return linearizedIndices;
153 }
154 
155 namespace {
156 
157 //===----------------------------------------------------------------------===//
158 // ConvertMemRefAllocation
159 //===----------------------------------------------------------------------===//
160 
161 template <typename OpTy>
162 struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
164 
165  LogicalResult
166  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
167  ConversionPatternRewriter &rewriter) const override {
168  static_assert(std::is_same<OpTy, memref::AllocOp>() ||
169  std::is_same<OpTy, memref::AllocaOp>(),
170  "expected only memref::AllocOp or memref::AllocaOp");
171  auto currentType = cast<MemRefType>(op.getMemref().getType());
172  auto newResultType =
173  this->getTypeConverter()->template convertType<MemRefType>(
174  op.getType());
175  if (!newResultType) {
176  return rewriter.notifyMatchFailure(
177  op->getLoc(),
178  llvm::formatv("failed to convert memref type: {0}", op.getType()));
179  }
180 
181  // Special case zero-rank memrefs.
182  if (currentType.getRank() == 0) {
183  rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
184  adaptor.getSymbolOperands(),
185  adaptor.getAlignmentAttr());
186  return success();
187  }
188 
189  Location loc = op.getLoc();
190  OpFoldResult zero = rewriter.getIndexAttr(0);
191  SmallVector<OpFoldResult> indices(currentType.getRank(), zero);
192 
193  // Get linearized type.
194  int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
195  int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
196  SmallVector<OpFoldResult> sizes = op.getMixedSizes();
197 
198  memref::LinearizedMemRefInfo linearizedMemRefInfo =
200  rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
201  SmallVector<Value> dynamicLinearizedSize;
202  if (!newResultType.hasStaticShape()) {
203  dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
204  rewriter, loc, linearizedMemRefInfo.linearizedSize));
205  }
206 
207  rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
208  adaptor.getSymbolOperands(),
209  adaptor.getAlignmentAttr());
210  return success();
211  }
212 };
213 
214 //===----------------------------------------------------------------------===//
215 // ConvertMemRefAssumeAlignment
216 //===----------------------------------------------------------------------===//
217 
218 struct ConvertMemRefAssumeAlignment final
219  : OpConversionPattern<memref::AssumeAlignmentOp> {
221 
222  LogicalResult
223  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
224  ConversionPatternRewriter &rewriter) const override {
225  Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
226  if (!newTy) {
227  return rewriter.notifyMatchFailure(
228  op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
229  op.getMemref().getType()));
230  }
231 
232  rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
233  op, adaptor.getMemref(), adaptor.getAlignmentAttr());
234  return success();
235  }
236 };
237 
238 //===----------------------------------------------------------------------===//
239 // ConvertMemRefCopy
240 //===----------------------------------------------------------------------===//
241 
242 struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
244 
245  LogicalResult
246  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
247  ConversionPatternRewriter &rewriter) const override {
248  auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
249  auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
250  if (maybeRankedSource && maybeRankedDest &&
251  maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
252  return rewriter.notifyMatchFailure(
253  op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
254  "and {1}) is currently unimplemented",
255  maybeRankedSource.getLayout(),
256  maybeRankedDest.getLayout()));
257  rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
258  adaptor.getTarget());
259  return success();
260  }
261 };
262 
263 //===----------------------------------------------------------------------===//
264 // ConvertMemRefDealloc
265 //===----------------------------------------------------------------------===//
266 
267 struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
269 
270  LogicalResult
271  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
272  ConversionPatternRewriter &rewriter) const override {
273  rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
274  return success();
275  }
276 };
277 
278 //===----------------------------------------------------------------------===//
279 // ConvertMemRefLoad
280 //===----------------------------------------------------------------------===//
281 
282 struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
284 
285  LogicalResult
286  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
287  ConversionPatternRewriter &rewriter) const override {
288  auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
289  auto convertedElementType = convertedType.getElementType();
290  auto oldElementType = op.getMemRefType().getElementType();
291  int srcBits = oldElementType.getIntOrFloatBitWidth();
292  int dstBits = convertedElementType.getIntOrFloatBitWidth();
293  if (dstBits % srcBits != 0) {
294  return rewriter.notifyMatchFailure(
295  op, "only dstBits % srcBits == 0 supported");
296  }
297 
298  Location loc = op.getLoc();
299  // Special case 0-rank memref loads.
300  Value bitsLoad;
301  if (convertedType.getRank() == 0) {
302  bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
303  ValueRange{});
304  } else {
305  // Linearize the indices of the original load instruction. Do not account
306  // for the scaling yet. This will be accounted for later.
307  OpFoldResult linearizedIndices = getLinearizedSrcIndices(
308  rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
309 
310  Value newLoad = rewriter.create<memref::LoadOp>(
311  loc, adaptor.getMemref(),
312  getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
313  dstBits));
314 
315  // Get the offset and shift the bits to the rightmost.
316  // Note, currently only the big-endian is supported.
317  Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
318  srcBits, dstBits, rewriter);
319  bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
320  }
321 
322  // Get the corresponding bits. If the arith computation bitwidth equals
323  // to the emulated bitwidth, we apply a mask to extract the low bits.
324  // It is not clear if this case actually happens in practice, but we keep
325  // the operations just in case. Otherwise, if the arith computation bitwidth
326  // is different from the emulated bitwidth we truncate the result.
327  Operation *result;
328  auto resultTy = getTypeConverter()->convertType(oldElementType);
329  if (resultTy == convertedElementType) {
330  auto mask = rewriter.create<arith::ConstantOp>(
331  loc, convertedElementType,
332  rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
333 
334  result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
335  } else {
336  result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
337  }
338 
339  rewriter.replaceOp(op, result->getResult(0));
340  return success();
341  }
342 };
343 
344 //===----------------------------------------------------------------------===//
345 // ConvertMemRefMemorySpaceCast
346 //===----------------------------------------------------------------------===//
347 
348 struct ConvertMemRefMemorySpaceCast final
349  : OpConversionPattern<memref::MemorySpaceCastOp> {
351 
352  LogicalResult
353  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
354  ConversionPatternRewriter &rewriter) const override {
355  Type newTy = getTypeConverter()->convertType(op.getDest().getType());
356  if (!newTy) {
357  return rewriter.notifyMatchFailure(
358  op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
359  op.getDest().getType()));
360  }
361 
362  rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
363  adaptor.getSource());
364  return success();
365  }
366 };
367 
368 //===----------------------------------------------------------------------===//
369 // ConvertMemRefReinterpretCast
370 //===----------------------------------------------------------------------===//
371 
372 /// Output types should be at most one dimensional, so only the 0 or 1
373 /// dimensional cases are supported.
374 struct ConvertMemRefReinterpretCast final
375  : OpConversionPattern<memref::ReinterpretCastOp> {
377 
378  LogicalResult
379  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
380  ConversionPatternRewriter &rewriter) const override {
381  MemRefType newTy =
382  getTypeConverter()->convertType<MemRefType>(op.getType());
383  if (!newTy) {
384  return rewriter.notifyMatchFailure(
385  op->getLoc(),
386  llvm::formatv("failed to convert memref type: {0}", op.getType()));
387  }
388 
389  // Only support for 0 or 1 dimensional cases.
390  if (op.getType().getRank() > 1) {
391  return rewriter.notifyMatchFailure(
392  op->getLoc(), "subview with rank > 1 is not supported");
393  }
394 
395  return convertCastingOp(rewriter, adaptor, op, newTy);
396  }
397 };
398 
399 //===----------------------------------------------------------------------===//
400 // ConvertMemrefStore
401 //===----------------------------------------------------------------------===//
402 
403 struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
405 
406  LogicalResult
407  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
408  ConversionPatternRewriter &rewriter) const override {
409  auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
410  int srcBits = op.getMemRefType().getElementTypeBitWidth();
411  int dstBits = convertedType.getElementTypeBitWidth();
412  auto dstIntegerType = rewriter.getIntegerType(dstBits);
413  if (dstBits % srcBits != 0) {
414  return rewriter.notifyMatchFailure(
415  op, "only dstBits % srcBits == 0 supported");
416  }
417 
418  Location loc = op.getLoc();
419  Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
420  adaptor.getValue());
421 
422  // Special case 0-rank memref stores. No need for masking.
423  if (convertedType.getRank() == 0) {
424  rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
425  extendedInput, adaptor.getMemref(),
426  ValueRange{});
427  rewriter.eraseOp(op);
428  return success();
429  }
430 
431  OpFoldResult linearizedIndices = getLinearizedSrcIndices(
432  rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
433  Value storeIndices = getIndicesForLoadOrStore(
434  rewriter, loc, linearizedIndices, srcBits, dstBits);
435  Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
436  dstBits, rewriter);
437  Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
438  dstBits, bitwidthOffset, rewriter);
439  // Align the value to write with the destination bits
440  Value alignedVal =
441  rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
442 
443  // Clear destination bits
444  rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
445  writeMask, adaptor.getMemref(),
446  storeIndices);
447  // Write srcs bits to destination
448  rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
449  alignedVal, adaptor.getMemref(),
450  storeIndices);
451  rewriter.eraseOp(op);
452  return success();
453  }
454 };
455 
456 //===----------------------------------------------------------------------===//
457 // ConvertMemRefSubview
458 //===----------------------------------------------------------------------===//
459 
460 /// Emulating narrow ints on subview have limited support, supporting only
461 /// static offset and size and stride of 1. Ideally, the subview should be
462 /// folded away before running narrow type emulation, and this pattern should
463 /// only run for cases that can't be folded.
464 struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
466 
467  LogicalResult
468  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
469  ConversionPatternRewriter &rewriter) const override {
470  MemRefType newTy =
471  getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
472  if (!newTy) {
473  return rewriter.notifyMatchFailure(
474  subViewOp->getLoc(),
475  llvm::formatv("failed to convert memref type: {0}",
476  subViewOp.getType()));
477  }
478 
479  Location loc = subViewOp.getLoc();
480  Type convertedElementType = newTy.getElementType();
481  Type oldElementType = subViewOp.getType().getElementType();
482  int srcBits = oldElementType.getIntOrFloatBitWidth();
483  int dstBits = convertedElementType.getIntOrFloatBitWidth();
484  if (dstBits % srcBits != 0)
485  return rewriter.notifyMatchFailure(
486  subViewOp, "only dstBits % srcBits == 0 supported");
487 
488  // Only support stride of 1.
489  if (llvm::any_of(subViewOp.getStaticStrides(),
490  [](int64_t stride) { return stride != 1; })) {
491  return rewriter.notifyMatchFailure(subViewOp->getLoc(),
492  "stride != 1 is not supported");
493  }
494 
495  if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
496  return rewriter.notifyMatchFailure(
497  subViewOp, "the result memref type is not contiguous");
498  }
499 
500  auto sizes = subViewOp.getStaticSizes();
501  int64_t lastOffset = subViewOp.getStaticOffsets().back();
502  // Only support static sizes and offsets.
503  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
504  lastOffset == ShapedType::kDynamic) {
505  return rewriter.notifyMatchFailure(
506  subViewOp->getLoc(), "dynamic size or offset is not supported");
507  }
508 
509  // Transform the offsets, sizes and strides according to the emulation.
510  auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
511  loc, subViewOp.getViewSource());
512 
513  OpFoldResult linearizedIndices;
514  auto strides = stridedMetadata.getConstifiedMixedStrides();
515  memref::LinearizedMemRefInfo linearizedInfo;
516  std::tie(linearizedInfo, linearizedIndices) =
518  rewriter, loc, srcBits, dstBits,
519  stridedMetadata.getConstifiedMixedOffset(),
520  subViewOp.getMixedSizes(), strides,
521  getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
522  rewriter));
523 
524  rewriter.replaceOpWithNewOp<memref::SubViewOp>(
525  subViewOp, newTy, adaptor.getSource(), linearizedIndices,
526  linearizedInfo.linearizedSize, strides.back());
527  return success();
528  }
529 };
530 
531 //===----------------------------------------------------------------------===//
532 // ConvertMemRefCollapseShape
533 //===----------------------------------------------------------------------===//
534 
535 /// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
536 /// that we flatten memrefs to a single dimension as part of the emulation and
537 /// there is no dimension to collapse any further.
538 struct ConvertMemRefCollapseShape final
539  : OpConversionPattern<memref::CollapseShapeOp> {
541 
542  LogicalResult
543  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
544  ConversionPatternRewriter &rewriter) const override {
545  Value srcVal = adaptor.getSrc();
546  auto newTy = dyn_cast<MemRefType>(srcVal.getType());
547  if (!newTy)
548  return failure();
549 
550  if (newTy.getRank() != 1)
551  return failure();
552 
553  rewriter.replaceOp(collapseShapeOp, srcVal);
554  return success();
555  }
556 };
557 
558 /// Emulating a `memref.expand_shape` becomes a no-op after emulation given
559 /// that we flatten memrefs to a single dimension as part of the emulation and
560 /// the expansion would just have been undone.
561 struct ConvertMemRefExpandShape final
562  : OpConversionPattern<memref::ExpandShapeOp> {
564 
565  LogicalResult
566  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
567  ConversionPatternRewriter &rewriter) const override {
568  Value srcVal = adaptor.getSrc();
569  auto newTy = dyn_cast<MemRefType>(srcVal.getType());
570  if (!newTy)
571  return failure();
572 
573  if (newTy.getRank() != 1)
574  return failure();
575 
576  rewriter.replaceOp(expandShapeOp, srcVal);
577  return success();
578  }
579 };
580 } // end anonymous namespace
581 
582 //===----------------------------------------------------------------------===//
583 // Public Interface Definition
584 //===----------------------------------------------------------------------===//
585 
587  const arith::NarrowTypeEmulationConverter &typeConverter,
588  RewritePatternSet &patterns) {
589 
590  // Populate `memref.*` conversion patterns.
591  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
592  ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
593  ConvertMemRefDealloc, ConvertMemRefCollapseShape,
594  ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
595  ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
596  ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
597  typeConverter, patterns.getContext());
599 }
600 
601 static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
602  int dstBits) {
603  if (ty.getRank() == 0)
604  return {};
605 
606  int64_t linearizedShape = 1;
607  for (auto shape : ty.getShape()) {
608  if (shape == ShapedType::kDynamic)
609  return {ShapedType::kDynamic};
610  linearizedShape *= shape;
611  }
612  int scale = dstBits / srcBits;
613  // Scale the size to the ceilDiv(linearizedShape, scale)
614  // to accomodate all the values.
615  linearizedShape = (linearizedShape + scale - 1) / scale;
616  return {linearizedShape};
617 }
618 
620  arith::NarrowTypeEmulationConverter &typeConverter) {
621  typeConverter.addConversion(
622  [&typeConverter](MemRefType ty) -> std::optional<Type> {
623  auto intTy = dyn_cast<IntegerType>(ty.getElementType());
624  if (!intTy)
625  return ty;
626 
627  unsigned width = intTy.getWidth();
628  unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
629  if (width >= loadStoreWidth)
630  return ty;
631 
632  // Currently only handle innermost stride being 1, checking
633  SmallVector<int64_t> strides;
634  int64_t offset;
635  if (failed(getStridesAndOffset(ty, strides, offset)))
636  return nullptr;
637  if (!strides.empty() && strides.back() != 1)
638  return nullptr;
639 
640  auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
641  intTy.getSignedness());
642  if (!newElemTy)
643  return nullptr;
644 
645  StridedLayoutAttr layoutAttr;
646  // If the offset is 0, we do not need a strided layout as the stride is
647  // 1, so we only use the strided layout if the offset is not 0.
648  if (offset != 0) {
649  if (offset == ShapedType::kDynamic) {
650  layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
651  ArrayRef<int64_t>{1});
652  } else {
653  // Check if the number of bytes are a multiple of the loadStoreWidth
654  // and if so, divide it by the loadStoreWidth to get the offset.
655  if ((offset * width) % loadStoreWidth != 0)
656  return std::nullopt;
657  offset = (offset * width) / loadStoreWidth;
658 
659  layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
660  ArrayRef<int64_t>{1});
661  }
662  }
663 
664  return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
665  newElemTy, layoutAttr, ty.getMemorySpace());
666  });
667 }
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
When data is loaded/stored in targetBits granularity, but is used in sourceBits granularity (sourceBi...
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc, OpFoldResult linearizedIndex, int64_t srcBits, int64_t dstBits)
Returns the scaled linearized index based on the srcBits and dstBits sizes.
static SmallVector< int64_t > getLinearizedShape(MemRefType ty, int srcBits, int dstBits)
static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter, memref::ReinterpretCastOp::Adaptor adaptor, memref::ReinterpretCastOp op, MemRefType newTy)
Converts a memref::ReinterpretCastOp to the converted type.
static OpFoldResult getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits, const SmallVector< OpFoldResult > &indices, Value memref)
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices, int64_t srcBits, int64_t dstBits, Value bitwidthOffset, OpBuilder &builder)
When writing a subbyte size, masked bitwise operations are used to only modify the relevant bits.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:917
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Converts narrow integer or float types that are not supported by the target hardware to wider types.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1194
void populateMemRefNarrowTypeEmulationConversions(arith::NarrowTypeEmulationConverter &typeConverter)
Appends type conversions for emulating memref operations over narrow types with ops over wider types.
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
Definition: MemRefUtils.cpp:52
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
Definition: MemRefUtils.cpp:24
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
void populateMemRefNarrowTypeEmulationPatterns(const arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns)
Appends patterns for emulating memref operations over narrow types with ops over wider types.
Include the generated interface declarations.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:362
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
Definition: MemRefUtils.h:50