1 //===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to emulate
10 // narrow types that are not supported by the target hardware, e.g. i4, using
11 // wider types, e.g. i8.
12 //
13 /// Currently, only power-of-two integer types are supported. These are
14 /// converted to wider integers that are 8 bits wide or wider.
15 ///
16 /// TODO: Support for non-powers-of-two.
17 //===----------------------------------------------------------------------===//
18 
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
22 #include "mlir/Dialect/Arith/Utils/Utils.h"
23 #include "mlir/Dialect/MemRef/IR/MemRef.h"
24 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/BuiltinAttributes.h"
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/OpDefinition.h"
31 #include "mlir/IR/TypeUtilities.h"
32 #include "mlir/IR/Value.h"
33 #include "mlir/Transforms/DialectConversion.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/MathExtras.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <cstdint>
39 #include <optional>
40 
41 using namespace mlir;
42 
43 #define DEBUG_TYPE "vector-narrow-type-emulation"
44 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
45 #define DBGSNL() (llvm::dbgs() << "\n")
46 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
47 
48 /// Returns a compressed mask for the emulated vector. For example, when
49 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
50 /// elements span two dest elements), this method compresses `vector<8xi1>`
51 /// into `vector<2xi1>`.
52 ///
53 /// A bit of the compressed/output mask is set iff any bit in the corresponding
54 /// `numSrcElemsPerDest`-wide range of the uncompressed/input mask is set. E.g.,
55 /// if `numSrcElemsPerDest` equals 2 and `numFrontPadElems` equals 1, the
56 /// following mask:
57 ///
58 /// %mask = [1, 1, 0, 0, 0, 0]
59 ///
60 /// will first be padded in the front with `numFrontPadElems` zeros, and zeros
61 /// will be added in the back to make the number of elements a multiple of
62 /// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
63 ///
64 /// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
65 ///
66 /// then it will return the following new compressed mask:
67 ///
68 /// %mask = [1, 1, 0, 0]
69 ///
70 /// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
71 /// `numSrcElemsPerDest`.
72 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
73  Location loc, Value mask,
74  int numSrcElems,
75  int numSrcElemsPerDest,
76  int numFrontPadElems = 0) {
77 
78  assert(numFrontPadElems < numSrcElemsPerDest &&
79  "numFrontPadElems must be less than numSrcElemsPerDest");
80 
81  auto numDestElems =
82  (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
83  numSrcElemsPerDest;
84 
85  Operation *maskOp = mask.getDefiningOp();
86  SmallVector<vector::ExtractOp, 2> extractOps;
87  // TODO: add support to `vector.splat`.
88  // Finding the mask creation operation.
89  while (maskOp &&
90  !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
91  maskOp)) {
92  if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
93  maskOp = extractOp.getVector().getDefiningOp();
94  extractOps.push_back(extractOp);
95  } else {
 // Bail out on mask-defining ops that are not handled here instead of
 // spinning in this loop forever.
 break;
 }
96  }
97 
98  if (!maskOp ||
 !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
99  maskOp))
100  return failure();
101 
102  // Computing the "compressed" mask. All the emulation logic (i.e. computing
103  // new mask index) only happens on the last dimension of the vectors.
104  SmallVector<int64_t> maskShape(
105  cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
106  maskShape.back() = numDestElems;
107  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
108  std::optional<Operation *> newMask =
109  TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
110  .Case<vector::CreateMaskOp>(
111  [&](auto createMaskOp) -> std::optional<Operation *> {
112  OperandRange maskOperands = createMaskOp.getOperands();
113  // The `vector.create_mask` op creates a mask arrangement
114  // without any zeros at the front. Also, because
115  // `numFrontPadElems` is strictly smaller than
116  // `numSrcElemsPerDest`, the compressed mask generated by
117  // padding the original mask by `numFrontPadElems` will not
118  // have any zeros at the front either.
119  AffineExpr s0;
120  bindSymbols(rewriter.getContext(), s0);
121  s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
122  OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
123  OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
124  rewriter, loc, s0, origIndex);
125  SmallVector<Value> newMaskOperands(maskOperands.drop_back());
126  newMaskOperands.push_back(
127  getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
128  return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
129  newMaskOperands);
130  })
131  .Case<vector::ConstantMaskOp>(
132  [&](auto constantMaskOp) -> std::optional<Operation *> {
133  // Take the shape of mask, compress its trailing dimension:
134  SmallVector<int64_t> maskDimSizes(
135  constantMaskOp.getMaskDimSizes());
136  int64_t &maskIndex = maskDimSizes.back();
137  maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
138  numSrcElemsPerDest);
139  return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
140  maskDimSizes);
141  })
142  .Case<arith::ConstantOp>([&](auto constantOp)
143  -> std::optional<Operation *> {
144  // TODO: Support multiple dimensions.
145  if (maskShape.size() != 1)
146  return std::nullopt;
147  // Rearrange the original mask values to cover the whole potential
148  // loading region. For example, in the case of using byte-size for
149  // emulation, given the following mask:
150  //
151  // %mask = [0, 1, 0, 1, 0, 0]
152  //
153  // With front offset of 1, the mask will be padded 0s in the front
154  // and back so that:
155  // 1. It is aligned with the effective loading bits
156  // 2. Its length is a multiple of `numSrcElemsPerDest` (and the total
157  // coverage size is a multiple of bytes). The new mask will be like
158  // this before compressing:
159  //
160  // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
161  auto originalMask =
162  cast<DenseIntElementsAttr>(constantOp.getValue());
163  SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
164  paddedMaskValues.append(originalMask.template value_begin<bool>(),
165  originalMask.template value_end<bool>());
166  paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
167 
168  // Compressing by combining every `numSrcElemsPerDest` elements:
169  SmallVector<bool> compressedMaskValues;
170  for (size_t i = 0; i < paddedMaskValues.size();
171  i += numSrcElemsPerDest) {
172  bool combinedValue = false;
173  for (int j = 0; j < numSrcElemsPerDest; ++j) {
174  combinedValue |= paddedMaskValues[i + j];
175  }
176  compressedMaskValues.push_back(combinedValue);
177  }
178  return rewriter.create<arith::ConstantOp>(
179  loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
180  });
181 
182  if (!newMask)
183  return failure();
184 
185  while (!extractOps.empty()) {
186  newMask = rewriter.create<vector::ExtractOp>(
187  loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
188  extractOps.pop_back();
189  }
190 
191  return *newMask;
192 }
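// Illustrative sketch (not from the original source): with `numSrcElems` = 8,
// `numSrcElemsPerDest` = 4 and `numFrontPadElems` = 1, the `vector.create_mask`
// case above rewrites
//
//   %mask = vector.create_mask %c5 : vector<8xi1>
//
// into
//
//   %new_mask = vector.create_mask %c2 : vector<3xi1>
//
// since ceilDiv(5 + 1, 4) == 2 and the compressed mask has
// ceilDiv(1 + 8, 4) == 3 elements.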
193 
194 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
195 /// emitting `vector.extract_strided_slice`.
196 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
197  VectorType extractType, Value source,
198  int64_t frontOffset,
199  int64_t subvecSize) {
200  auto vectorType = cast<VectorType>(source.getType());
201  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
202  "expected 1-D source and destination types");
203  (void)vectorType;
204  assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
205  "subvector out of bounds");
206 
207  // No extraction is needed if the subvector size is the same as the source.
208  if (vectorType.getNumElements() == subvecSize)
209  return source;
210 
211  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
212  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
213  auto strides = rewriter.getI64ArrayAttr({1});
214  return rewriter
215  .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
216  sizes, strides)
217  ->getResult(0);
218 }
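// For example (illustrative), with `frontOffset` = 2 and `subvecSize` = 3 on a
// `vector<8xi2>` source, the helper above emits:
//
//   %slice = vector.extract_strided_slice %source
//     {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>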
219 
220 /// Inserts 1-D subvector into a 1-D vector by overwriting the elements starting
221 /// at `offset`. It is a wrapper function for emitting
222 /// `vector.insert_strided_slice`.
223 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
224  Value src, Value dest, int64_t offset) {
225  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
226  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
227  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
228  "expected source and dest to be vector type");
229  auto offsets = rewriter.getI64ArrayAttr({offset});
230  auto strides = rewriter.getI64ArrayAttr({1});
231  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
232  dest, offsets, strides);
233 }
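// For example (illustrative), inserting a `vector<3xi2>` at `offset` = 2 into a
// `vector<8xi2>` destination, the helper above emits:
//
//   %res = vector.insert_strided_slice %src, %dest
//     {offsets = [2], strides = [1]} : vector<3xi2> into vector<8xi2>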
234 
235 /// Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
236 /// and size `numElementsToExtract`, and inserts into the `dest` vector. This
237 /// function emits multiple `vector.extract` and `vector.insert` ops, so only
238 /// use it when `offset` cannot be folded into a constant value.
239 static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
240  TypedValue<VectorType> source,
241  Value dest, OpFoldResult offset,
242  int64_t numElementsToExtract) {
243  for (int i = 0; i < numElementsToExtract; ++i) {
244  Value extractLoc =
245  (i == 0) ? offset.dyn_cast<Value>()
246  : rewriter.create<arith::AddIOp>(
247  loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
248  rewriter.create<arith::ConstantIndexOp>(loc, i));
249  auto extractOp =
250  rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
251  dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
252  }
253  return dest;
254 }
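// For example (illustrative), with `numElementsToExtract` = 2 and a dynamic
// `offset`, the helper above emits one extract/insert pair per element:
//
//   %e0 = vector.extract %source[%offset] : i4 from vector<8xi4>
//   %d0 = vector.insert %e0, %dest[0] : i4 into vector<2xi4>
//   %o1 = arith.addi %offset, %c1 : index
//   %e1 = vector.extract %source[%o1] : i4 from vector<8xi4>
//   %d1 = vector.insert %e1, %d0[1] : i4 into vector<2xi4>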
255 
256 /// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
257 static Value dynamicallyInsertSubVector(OpBuilder &rewriter, Location loc,
258  TypedValue<VectorType> source,
259  Value dest, OpFoldResult destOffsetVar,
260  size_t length) {
261  assert(length > 0 && "length must be greater than 0");
262  Value destOffsetVal =
263  getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
264  for (size_t i = 0; i < length; ++i) {
265  auto insertLoc = i == 0
266  ? destOffsetVal
267  : rewriter.create<arith::AddIOp>(
268  loc, rewriter.getIndexType(), destOffsetVal,
269  rewriter.create<arith::ConstantIndexOp>(loc, i));
270  auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
271  dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
272  }
273  return dest;
274 }
275 
276 /// Returns the op sequence for an emulated sub-byte data type vector load.
277 /// Specifically, use `emulatedElemType` for loading a vector of `origElemType`.
278 /// The load location is given by `base` and `linearizedIndices`, and the
279 /// load size is given by `numEmulatedElementsToLoad`.
280 static TypedValue<VectorType>
281 emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
282  OpFoldResult linearizedIndices,
283  int64_t numEmulatedElementsToLoad, Type origElemType,
284  Type emulatedElemType) {
285  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
286  origElemType.getIntOrFloatBitWidth();
287  auto newLoad = rewriter.create<vector::LoadOp>(
288  loc, VectorType::get(numEmulatedElementsToLoad, emulatedElemType), base,
289  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
290  return rewriter.create<vector::BitCastOp>(
291  loc, VectorType::get(numEmulatedElementsToLoad * scale, origElemType),
292  newLoad);
293 }
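// For example (illustrative), loading 2 emulated i8 elements that back 4 i4
// values, the helper above emits:
//
//   %wide = vector.load %base[%index] : memref<16xi8>, vector<2xi8>
//   %res  = vector.bitcast %wide : vector<2xi8> to vector<4xi4>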
294 
295 namespace {
296 
297 //===----------------------------------------------------------------------===//
298 // ConvertVectorStore
299 //===----------------------------------------------------------------------===//
300 
301 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
302  using OpConversionPattern::OpConversionPattern;
303 
304  LogicalResult
305  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
306  ConversionPatternRewriter &rewriter) const override {
307 
308  // See #115653
309  if (op.getValueToStore().getType().getRank() != 1)
310  return rewriter.notifyMatchFailure(op,
311  "only 1-D vectors are supported ATM");
312 
313  auto loc = op.getLoc();
314  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
315  Type oldElementType = op.getValueToStore().getType().getElementType();
316  Type newElementType = convertedType.getElementType();
317  int srcBits = oldElementType.getIntOrFloatBitWidth();
318  int dstBits = newElementType.getIntOrFloatBitWidth();
319 
320  if (dstBits % srcBits != 0) {
321  return rewriter.notifyMatchFailure(
322  op, "only dstBits % srcBits == 0 supported");
323  }
324  int scale = dstBits / srcBits;
325 
326  // Adjust the number of elements to store when emulating narrow types.
327  // Here only the 1-D vector store is considered, and the N-D memref types
328  // should be linearized.
329  // For example, to emulate i4 to i8, the following op:
330  //
331  // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
332  //
333  // can be replaced with
334  //
335  // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
336  // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
337  // vector<4xi8>
338 
339  auto origElements = op.getValueToStore().getType().getNumElements();
340  if (origElements % scale != 0)
341  return failure();
342 
343  auto stridedMetadata =
344  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
345 
346  OpFoldResult linearizedIndices;
347  std::tie(std::ignore, linearizedIndices) =
348  memref::getLinearizedMemRefOffsetAndSize(
349  rewriter, loc, srcBits, dstBits,
350  stridedMetadata.getConstifiedMixedOffset(),
351  stridedMetadata.getConstifiedMixedSizes(),
352  stridedMetadata.getConstifiedMixedStrides(),
353  getAsOpFoldResult(adaptor.getIndices()));
354 
355  auto numElements = origElements / scale;
356  auto bitCast = rewriter.create<vector::BitCastOp>(
357  loc, VectorType::get(numElements, newElementType),
358  op.getValueToStore());
359 
360  rewriter.replaceOpWithNewOp<vector::StoreOp>(
361  op, bitCast.getResult(), adaptor.getBase(),
362  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
363  return success();
364  }
365 };
366 
367 //===----------------------------------------------------------------------===//
368 // ConvertVectorMaskedStore
369 //===----------------------------------------------------------------------===//
370 
371 struct ConvertVectorMaskedStore final
372  : OpConversionPattern<vector::MaskedStoreOp> {
373  using OpConversionPattern::OpConversionPattern;
374 
375  LogicalResult
376  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
377  ConversionPatternRewriter &rewriter) const override {
378 
379  // See #115653
380  if (op.getValueToStore().getType().getRank() != 1)
381  return rewriter.notifyMatchFailure(op,
382  "only 1-D vectors are supported ATM");
383 
384  auto loc = op.getLoc();
385  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
386  Type oldElementType = op.getValueToStore().getType().getElementType();
387  Type newElementType = convertedType.getElementType();
388  int srcBits = oldElementType.getIntOrFloatBitWidth();
389  int dstBits = newElementType.getIntOrFloatBitWidth();
390 
391  if (dstBits % srcBits != 0) {
392  return rewriter.notifyMatchFailure(
393  op, "only dstBits % srcBits == 0 supported");
394  }
395 
396  int scale = dstBits / srcBits;
397  int origElements = op.getValueToStore().getType().getNumElements();
398  if (origElements % scale != 0)
399  return failure();
400 
401  auto stridedMetadata =
402  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
403  OpFoldResult linearizedIndicesOfr;
404  memref::LinearizedMemRefInfo linearizedInfo;
405  std::tie(linearizedInfo, linearizedIndicesOfr) =
406  memref::getLinearizedMemRefOffsetAndSize(
407  rewriter, loc, srcBits, dstBits,
408  stridedMetadata.getConstifiedMixedOffset(),
409  stridedMetadata.getConstifiedMixedSizes(),
410  stridedMetadata.getConstifiedMixedStrides(),
411  getAsOpFoldResult(adaptor.getIndices()));
412  Value linearizedIndices =
413  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
414 
415  // Load the whole data and use arith.select to handle the corner cases.
416  //
417  // As an example, for this masked store of i4 values:
418  //
419  // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
420  //
421  // and given these input values:
422  //
423  // %mask = [0, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
424  // %0[%c0, %c0] =
425  // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
426  // %val_to_store =
427  // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
428  //
429  // we'll have the following i4 output:
430  //
431  // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
432  //
433  // Emulating the above using i8 will give:
434  //
435  // %compressed_mask = [1, 1, 1, 0] (4 * i1)
436  // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
437  // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
438  // %select_using_shifted_mask =
439  // [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
440  // %packed_data = [0x1A, 0xBC, 0xD6, 0x00] (4 * i8)
441  //
442  // Using the compressed mask to store %packed_data results in expected
443  // output.
444  //
445  // FIXME: Make an example based on the comment above work (see #115460 for
446  // reproducer).
447  FailureOr<Operation *> newMask =
448  getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
449  if (failed(newMask))
450  return failure();
451 
452  auto numElements = (origElements + scale - 1) / scale;
453  auto newType = VectorType::get(numElements, newElementType);
454  auto passThru = rewriter.create<arith::ConstantOp>(
455  loc, newType, rewriter.getZeroAttr(newType));
456 
457  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
458  loc, newType, adaptor.getBase(), linearizedIndices,
459  newMask.value()->getResult(0), passThru);
460 
461  auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
462  Value valueToStore =
463  rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
464  valueToStore = rewriter.create<arith::SelectOp>(
465  loc, op.getMask(), op.getValueToStore(), valueToStore);
466  valueToStore =
467  rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
468 
469  rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
470  op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
471  valueToStore);
472  return success();
473  }
474 };
475 
476 //===----------------------------------------------------------------------===//
477 // ConvertVectorLoad
478 //===----------------------------------------------------------------------===//
479 
480 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
481  using OpConversionPattern::OpConversionPattern;
482 
483  LogicalResult
484  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
485  ConversionPatternRewriter &rewriter) const override {
486 
487  // See #115653
488  if (op.getVectorType().getRank() != 1)
489  return rewriter.notifyMatchFailure(op,
490  "only 1-D vectors are supported ATM");
491 
492  auto loc = op.getLoc();
493  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
494  Type oldElementType = op.getType().getElementType();
495  Type newElementType = convertedType.getElementType();
496  int srcBits = oldElementType.getIntOrFloatBitWidth();
497  int dstBits = newElementType.getIntOrFloatBitWidth();
498 
499  if (dstBits % srcBits != 0) {
500  return rewriter.notifyMatchFailure(
501  op, "only dstBits % srcBits == 0 supported");
502  }
503  int scale = dstBits / srcBits;
504 
505  // Adjust the number of elements to load when emulating narrow types,
506  // and then cast back to the original type with vector.bitcast op.
507  // Here only the 1-D vector load is considered, and the N-D memref types
508  // should be linearized.
509  // For example, to emulate i4 to i8, the following op:
510  //
511  // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
512  //
513  // can be replaced with
514  //
515  // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
516  // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
517  //
518  // There are cases where the number of elements to load is not byte-aligned,
519  // for example:
520  //
521  // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
522  //
523  // we will have to load extra bytes and extract the exact slice in between.
524  //
525  // %1 = vector.load %0[%c0] : memref<3xi8>, vector<2xi8>
526  // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
527  // %3 = vector.extract_strided_slice %2 {offsets = [3], sizes = [3], strides
528  // = [1]}
529  // : vector<8xi2> to vector<3xi2>
530  //
531  // TODO: Currently the extract_strided_slice's attributes must be known at
532  // compile time as they must be constants.
533 
534  auto origElements = op.getVectorType().getNumElements();
535  bool isUnalignedEmulation = origElements % scale != 0;
536 
537  auto stridedMetadata =
538  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
539 
540  OpFoldResult linearizedIndices;
541  memref::LinearizedMemRefInfo linearizedInfo;
542  std::tie(linearizedInfo, linearizedIndices) =
543  memref::getLinearizedMemRefOffsetAndSize(
544  rewriter, loc, srcBits, dstBits,
545  stridedMetadata.getConstifiedMixedOffset(),
546  stridedMetadata.getConstifiedMixedSizes(),
547  stridedMetadata.getConstifiedMixedStrides(),
548  getAsOpFoldResult(adaptor.getIndices()));
549 
550  std::optional<int64_t> foldedIntraVectorOffset =
551  isUnalignedEmulation
552  ? getConstantIntValue(linearizedInfo.intraDataOffset)
553  : 0;
554 
555  // Always load enough elements that can cover the original elements.
556  int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
557  auto numElements =
558  llvm::divideCeil(maxIntraDataOffset + origElements, scale);
559  Value result =
560  emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
561  numElements, oldElementType, newElementType);
562 
563  if (!foldedIntraVectorOffset) {
564  auto resultVector = rewriter.create<arith::ConstantOp>(
565  loc, op.getType(), rewriter.getZeroAttr(op.getType()));
566  result = dynamicallyExtractSubVector(
567  rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
568  linearizedInfo.intraDataOffset, origElements);
569  } else if (isUnalignedEmulation) {
570  result =
571  staticallyExtractSubvector(rewriter, loc, op.getType(), result,
572  *foldedIntraVectorOffset, origElements);
573  }
574  rewriter.replaceOp(op, result);
575  return success();
576  }
577 };
578 
579 //===----------------------------------------------------------------------===//
580 // ConvertVectorMaskedLoad
581 //===----------------------------------------------------------------------===//
582 
583 struct ConvertVectorMaskedLoad final
584  : OpConversionPattern<vector::MaskedLoadOp> {
585  using OpConversionPattern::OpConversionPattern;
586 
587  LogicalResult
588  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
589  ConversionPatternRewriter &rewriter) const override {
590  // See #115653
591  if (op.getVectorType().getRank() != 1)
592  return rewriter.notifyMatchFailure(op,
593  "only 1-D vectors are supported ATM");
594 
595  auto loc = op.getLoc();
596  auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
597  Type oldElementType = op.getType().getElementType();
598  Type newElementType = convertedType.getElementType();
599  int srcBits = oldElementType.getIntOrFloatBitWidth();
600  int dstBits = newElementType.getIntOrFloatBitWidth();
601 
602  if (dstBits % srcBits != 0) {
603  return rewriter.notifyMatchFailure(
604  op, "only dstBits % srcBits == 0 supported");
605  }
606  int scale = dstBits / srcBits;
607 
608  // Adjust the number of elements to load when emulating narrow types,
609  // and then cast back to the original type with vector.bitcast op.
610  // For example, to emulate i4 to i8, the following op:
611  //
612  // %mask = vector.constant_mask [3] : vector<6xi1>
613  // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
614  // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
615  //
616  // can be replaced with
617  //
618  // %new_mask = vector.constant_mask [2] : vector<3xi1>
619  // %new_pass_thru = vector.bitcast %pass_thru :
620  // vector<6xi4> to vector<3xi8>
621  // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
622  // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
623  // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
624  //
625  // Since we are effectively loading 16 bits (2xi8) from the memref with the
626  // new mask, while originally we only wanted to effectively load 12 bits
627  // (3xi4) from the memref, we need to set the second half of the last i8
628  // that was effectively loaded (i.e. the second i8) to %pass_thru.
629  //
630  // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
631  //
632  // Given these input values:
633  // %mask = [1, 1, 1, 0, 0, 0]
634  // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
635  // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
636  //
637  // we'll have:
638  //
639  // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
640  //
641  // %new_mask = [1, 1, 0]
642  // %new_pass_thru = [0x78, 0x9A, 0xBC]
643  // %1 = [0x12, 0x34, 0xBC]
644  // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
645  // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
646  //
647  // TODO: Currently, only loading an even number of elements is supported.
648  // To deal with the odd number of elements, one has to extract the
649  // subvector at the proper offset after bit-casting.
650  auto origType = op.getVectorType();
651  auto origElements = origType.getNumElements();
652  bool isUnalignedEmulation = origElements % scale != 0;
653 
654  auto stridedMetadata =
655  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
656  OpFoldResult linearizedIndices;
657  memref::LinearizedMemRefInfo linearizedInfo;
658  std::tie(linearizedInfo, linearizedIndices) =
659  memref::getLinearizedMemRefOffsetAndSize(
660  rewriter, loc, srcBits, dstBits,
661  stridedMetadata.getConstifiedMixedOffset(),
662  stridedMetadata.getConstifiedMixedSizes(),
663  stridedMetadata.getConstifiedMixedStrides(),
664  getAsOpFoldResult(adaptor.getIndices()));
665 
666  std::optional<int64_t> foldedIntraVectorOffset =
667  isUnalignedEmulation
668  ? getConstantIntValue(linearizedInfo.intraDataOffset)
669  : 0;
670 
671  int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
672  FailureOr<Operation *> newMask = getCompressedMaskOp(
673  rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
674  if (failed(newMask))
675  return failure();
676 
677  Value passthru = op.getPassThru();
678 
679  auto numElements =
680  llvm::divideCeil(maxIntraDataOffset + origElements, scale);
681  auto loadType = VectorType::get(numElements, newElementType);
682  auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
683 
684  auto emptyVector = rewriter.create<arith::ConstantOp>(
685  loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
686  if (!foldedIntraVectorOffset) {
687  passthru = dynamicallyInsertSubVector(
688  rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
689  emptyVector, linearizedInfo.intraDataOffset, origElements);
690  } else if (isUnalignedEmulation) {
691  passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
692  *foldedIntraVectorOffset);
693  }
694  auto newPassThru =
695  rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
696 
697  // Generating the new masked load.
698  auto newLoad = rewriter.create<vector::MaskedLoadOp>(
699  loc, loadType, adaptor.getBase(),
700  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
701  newMask.value()->getResult(0), newPassThru);
702 
703  // Setting the part that originally was not effectively loaded from memory
704  // to pass through.
705  auto bitCast =
706  rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
707 
708  Value mask = op.getMask();
709  auto newSelectMaskType =
710  VectorType::get(numElements * scale, rewriter.getI1Type());
711  // TODO: try to fold if op's mask is constant
712  auto emptyMask = rewriter.create<arith::ConstantOp>(
713  loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
714  if (!foldedIntraVectorOffset) {
715  mask = dynamicallyInsertSubVector(
716  rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
717  linearizedInfo.intraDataOffset, origElements);
718  } else if (isUnalignedEmulation) {
719  mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
720  *foldedIntraVectorOffset);
721  }
722 
723  Value result =
724  rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
725  if (!foldedIntraVectorOffset) {
726  result = dynamicallyExtractSubVector(
727  rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
728  op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
729  } else if (isUnalignedEmulation) {
730  result =
731  staticallyExtractSubvector(rewriter, loc, op.getType(), result,
732  *foldedIntraVectorOffset, origElements);
733  }
734  rewriter.replaceOp(op, result);
735 
736  return success();
737  }
738 };
739 
740 //===----------------------------------------------------------------------===//
741 // ConvertVectorTransferRead
742 //===----------------------------------------------------------------------===//
743 
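/// Illustrative sketch of the rewrite below (not from the original file),
/// assuming an aligned i4 -> i8 case:
///
///   %1 = vector.transfer_read %0[%c0, %c0], %pad : memref<3x8xi4>, vector<8xi4>
///
/// is emulated as
///
///   %pad_i8 = arith.extui %pad : i4 to i8
///   %1 = vector.transfer_read %alloc[%linear_index], %pad_i8 :
///        memref<12xi8>, vector<4xi8>
///   %2 = vector.bitcast %1 : vector<4xi8> to vector<8xi4>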
744 struct ConvertVectorTransferRead final
745  : OpConversionPattern<vector::TransferReadOp> {
746  using OpConversionPattern::OpConversionPattern;
747 
748  LogicalResult
749  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
750  ConversionPatternRewriter &rewriter) const override {
751 
752  // See #115653
753  if (op.getVectorType().getRank() != 1)
754  return rewriter.notifyMatchFailure(op,
755  "only 1-D vectors are supported ATM");
756 
757  auto loc = op.getLoc();
758  auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
759  Type oldElementType = op.getType().getElementType();
760  Type newElementType = convertedType.getElementType();
761  int srcBits = oldElementType.getIntOrFloatBitWidth();
762  int dstBits = newElementType.getIntOrFloatBitWidth();
763 
764  if (dstBits % srcBits != 0) {
765  return rewriter.notifyMatchFailure(
766  op, "only dstBits % srcBits == 0 supported");
767  }
768  int scale = dstBits / srcBits;
769 
770  auto origElements = op.getVectorType().getNumElements();
771 
772  bool isUnalignedEmulation = origElements % scale != 0;
773 
774  auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
775  adaptor.getPadding());
776 
777  auto stridedMetadata =
778  rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
779 
780  OpFoldResult linearizedIndices;
781  memref::LinearizedMemRefInfo linearizedInfo;
782  std::tie(linearizedInfo, linearizedIndices) =
783  memref::getLinearizedMemRefOffsetAndSize(
784  rewriter, loc, srcBits, dstBits,
785  stridedMetadata.getConstifiedMixedOffset(),
786  stridedMetadata.getConstifiedMixedSizes(),
787  stridedMetadata.getConstifiedMixedStrides(),
788  getAsOpFoldResult(adaptor.getIndices()));
789 
790  std::optional<int64_t> foldedIntraVectorOffset =
791  isUnalignedEmulation
792  ? getConstantIntValue(linearizedInfo.intraDataOffset)
793  : 0;
794 
795  int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
796  auto numElements =
797  llvm::divideCeil(maxIntraDataOffset + origElements, scale);
798 
799  auto newRead = rewriter.create<vector::TransferReadOp>(
800  loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
801  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
802  newPadding);
803 
804  auto bitCast = rewriter.create<vector::BitCastOp>(
805  loc, VectorType::get(numElements * scale, oldElementType), newRead);
806 
807  Value result = bitCast->getResult(0);
808  if (!foldedIntraVectorOffset) {
809  auto zeros = rewriter.create<arith::ConstantOp>(
810  loc, op.getType(), rewriter.getZeroAttr(op.getType()));
811  result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
812  linearizedInfo.intraDataOffset,
813  origElements);
814  } else if (isUnalignedEmulation) {
815  result =
816  staticallyExtractSubvector(rewriter, loc, op.getType(), result,
817  *foldedIntraVectorOffset, origElements);
818  }
819  rewriter.replaceOp(op, result);
820 
821  return success();
822  }
823 };
824 } // end anonymous namespace
825 
826 //===----------------------------------------------------------------------===//
827 // RewriteBitCastOfTruncI
828 //===----------------------------------------------------------------------===//
829 
830 namespace {
831 
832 /// Helper struct to keep track of the provenance of a contiguous set of bits
833 /// in a source vector.
834 struct SourceElementRange {
835  /// The index of the source vector element that contributes bits to *this.
836  int64_t sourceElementIdx;
837  /// The range of bits in the source vector element that contribute to *this.
838  int64_t sourceBitBegin;
839  int64_t sourceBitEnd;
840 };
841 
842 struct SourceElementRangeList : public SmallVector<SourceElementRange> {
843  /// Given the index of a SourceElementRange in the SourceElementRangeList,
844  /// compute the amount of bits that need to be shifted to the left to get the
845  /// bits in their final location. This shift amount is simply the sum of the
846  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
847 /// the LSBs, the bits of `shuffleIdx = 1` come next, etc.).
848  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
849  int64_t res = 0;
850  for (int64_t i = 0; i < shuffleIdx; ++i)
851  res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
852  return res;
853  }
854 };
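// For example (illustrative), for the row
//   { 1: b@[3..5) }{ 2: b@[0..5) }{ 3: b@[0..1) }
// computeLeftShiftAmount(1) == 2 (the 2 bits of entry 0) and
// computeLeftShiftAmount(2) == 7 (the 2 + 5 bits of entries 0 and 1).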
855 
856 /// Helper struct to enumerate the source elements and bit ranges that are
857 /// involved in a bitcast operation.
858 /// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
859 /// any 1-D vector shape and any source/target bitwidths.
860 /// This creates and holds a mapping of the form:
861 /// [dstVectorElementJ] ==
862 /// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
863 /// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
864 /// [0] = {0, [0-8)}
865 /// [1] = {0, [8-16)}
866 /// [2] = {0, [16-24)}
867 /// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
868 /// [0] = {0, [0-10)}
869 /// [1] = {0, [10-15)}, {1, [0-5)}
/// [2] = {1, [5-15)}
870 struct BitCastBitsEnumerator {
871  BitCastBitsEnumerator(VectorType sourceVectorType,
872  VectorType targetVectorType);
873 
874  int64_t getMaxNumberOfEntries() {
875  int64_t numVectors = 0;
876  for (const auto &l : sourceElementRanges)
877  numVectors = std::max(numVectors, (int64_t)l.size());
878  return numVectors;
879  }
880 
881  VectorType sourceVectorType;
882  VectorType targetVectorType;
883  SmallVector<SourceElementRangeList> sourceElementRanges;
884 };
885 
886 /// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
887 /// advantage of high-level information to avoid leaving LLVM to scramble with
888 /// peephole optimizations.
889 /// BitCastBitsEnumerator encodes for each element of the target vector the
890 /// provenance of the bits in the source vector. We can "transpose" this
891 /// information to build a sequence of shuffles and bitwise ops that will
892 /// produce the desired result.
893 //
894 /// Consider the following motivating example:
895 /// ```
896 /// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
897 /// ```
898 //
899 /// BitCastBitsEnumerator contains the following information:
900 /// ```
901 /// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
902 /// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
903 /// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
904 /// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
905 /// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
906 /// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
907 /// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
908 /// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
909 /// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
910 /// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
911 /// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
912 /// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
913 /// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
914 /// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
915 /// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
916 /// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
917 /// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
918 /// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
919 /// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
920 /// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
921 /// ```
922 ///
923 /// In the above, each row represents one target vector element and each
924 /// column represents one bit contribution from a source vector element.
925 /// The algorithm creates vector.shuffle operations (in this case there are 3
926 /// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
927 /// algorithm populates the bits as follows:
928 /// ```
929 /// src bits 0 ...
930 /// 1st shuffle |xxxxx |xx |...
931 /// 2nd shuffle | xxx| xxxxx |...
932 /// 3rd shuffle | | x|...
933 /// ```
934 //
935 /// The algorithm proceeds as follows:
936 /// 1. for each vector.shuffle, collect the source vectors that participate in
937 /// this shuffle. One source vector per target element of the resulting
938 /// vector.shuffle. If there is no source element contributing bits for the
939 /// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
940 /// 2 columns).
941 /// 2. represent the bitrange in the source vector as a mask. If there is no
942 /// source element contributing bits for the current vector.shuffle, take 0.
943 /// 3. shift right by the proper amount to align the source bitrange at
944 /// position 0. This is exactly the low end of the bitrange. For instance,
945 /// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
946 /// shift right by 3 to get the bits contributed by the source element #1
947 /// into position 0.
948 /// 4. shift left by the proper amount to to align to the desired position in
949 /// the result element vector. For instance, the contribution of the second
950 /// source element for the first row needs to be shifted by `5` to form the
951 /// first i8 result element.
952 ///
953 /// Eventually, we end up building the sequence
954 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
955 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
956 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
957 struct BitCastRewriter {
958  /// Helper metadata struct to hold the static quantities for the rewrite.
959  struct Metadata {
960  SmallVector<int64_t> shuffles;
961  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
962  };
963 
964  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
965 
966  /// Verify that general preconditions for the rewrite are met.
967  LogicalResult commonPrecondition(PatternRewriter &rewriter,
968  VectorType preconditionType, Operation *op);
969 
970  /// Precompute the metadata for the rewrite.
971  SmallVector<BitCastRewriter::Metadata>
972  precomputeMetadata(IntegerType shuffledElementType);
973 
974  /// Rewrite one step of the sequence:
975  /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
976  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
977  Value initialValue, Value runningResult,
978  const BitCastRewriter::Metadata &metadata);
979 
980 private:
981  /// Underlying enumerator that encodes the provenance of the bits in each
982  /// element of the result vector.
983  BitCastBitsEnumerator enumerator;
984 };
985 
986 } // namespace
987 
988 [[maybe_unused]] static raw_ostream &
989 operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
990  for (const auto &l : vec) {
991  for (auto it : llvm::enumerate(l)) {
992  os << "{ " << it.value().sourceElementIdx << ": b@["
993  << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
994  << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
995  }
996  os << "\n";
997  }
998  return os;
999 }
1000 
1001 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1002  VectorType targetVectorType)
1003  : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1004 
1005  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
1006  "requires 1-D non-scalable vector type");
1007  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
1008  "requires 1-D non-scalable vector type");
1009  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1010  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1011  LDBG("sourceVectorType: " << sourceVectorType);
1012 
1013  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1014  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1015  LDBG("targetVectorType: " << targetVectorType);
1016 
1017  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1018  (void)mostMinorSourceDim;
1019  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1020  "source and target bitwidths must match");
1021 
1022  // Prepopulate one source element range per target element.
1023  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
1024  for (int64_t resultBit = 0; resultBit < bitwidth;) {
1025  int64_t resultElement = resultBit / targetBitWidth;
1026  int64_t resultBitInElement = resultBit % targetBitWidth;
1027  int64_t sourceElementIdx = resultBit / sourceBitWidth;
1028  int64_t sourceBitInElement = resultBit % sourceBitWidth;
1029  int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1030  targetBitWidth - resultBitInElement);
1031  sourceElementRanges[resultElement].push_back(
1032  {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
1033  resultBit += step;
1034  }
1035 }
1036 
1037 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1038  VectorType targetVectorType)
1039  : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1040  LDBG("\n" << enumerator.sourceElementRanges);
1041 }
1042 
1043 /// Verify that the precondition type meets the common preconditions for any
1044 /// conversion.
1045 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
1046  VectorType preconditionType,
1047  Operation *op) {
1048  if (!preconditionType || preconditionType.isScalable())
1049  return rewriter.notifyMatchFailure(op, "scalable vector");
1050 
1051  // TODO: consider relaxing this restriction in the future if we find ways
1052  // to really work with subbyte elements across the MLIR/LLVM boundary.
1053  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1054  if (bitwidth % 8 != 0)
1055  return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
1056 
1057  return success();
1058 }
1059 
1060 LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1061  VectorType preconditionType,
1062  Operation *op) {
1063  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1064  return rewriter.notifyMatchFailure(op, "types are not vector");
1065 
1066  if (!preconditionType || preconditionType.getRank() != 1)
1067  return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
1068 
1069  return commonConversionPrecondition(rewriter, preconditionType, op);
1070 }
1071 
1072 /// Verify that source and destination element types meet the precondition for
1073 /// the supported aligned conversion cases. Alignment means that either the
1074 /// source element type is a multiple of the destination element type or the
1075 /// other way around.
1076 ///
1077 /// NOTE: This method assumes that common conversion preconditions are met.
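///
/// For example (illustrative): `vector<8xi4> -> vector<8xi32>` satisfies this
/// precondition, while `vector<8xi3> -> vector<8xi8>` (source is not i4) and
/// `vector<7xi4> -> vector<7xi8>` (odd trailing dimension) do not.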
1078 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1079  VectorType srcType,
1080  VectorType dstType,
1081  Operation *op) {
1082  if (!srcType || !dstType)
1083  return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
1084  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
1085  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
1086 
1087  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
1088  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
1089  (dstElemBitwidth % srcElemBitwidth) != 0)
1090  return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
1091 
1092  if ((srcType.getShape().back() % 2) != 0)
1093  return rewriter.notifyMatchFailure(
1094  op, "Not an even number of i4 elements in trailing dim");
1095 
1096  return success();
1097 }
1098 
1100 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1101  SmallVector<BitCastRewriter::Metadata> result;
1102  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1103  shuffleIdx < e; ++shuffleIdx) {
1104  SmallVector<int64_t> shuffles;
1105  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1106 
1107  // Create the attribute quantities for the shuffle / mask / shift ops.
1108  for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1109  int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1110  ? srcEltRangeList[shuffleIdx].sourceElementIdx
1111  : 0;
1112  shuffles.push_back(sourceElement);
1113 
1114  int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1115  ? srcEltRangeList[shuffleIdx].sourceBitBegin
1116  : 0;
1117  int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1118  ? srcEltRangeList[shuffleIdx].sourceBitEnd
1119  : 0;
1120  IntegerAttr mask = IntegerAttr::get(
1121  shuffledElementType,
1122  llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1123  bitLo, bitHi));
1124  masks.push_back(mask);
1125 
1126  int64_t shiftRight = bitLo;
1127  shiftRightAmounts.push_back(
1128  IntegerAttr::get(shuffledElementType, shiftRight));
1129 
1130  int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1131  shiftLeftAmounts.push_back(
1132  IntegerAttr::get(shuffledElementType, shiftLeft));
1133  }
1134 
1135  result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1136  }
1137  return result;
1138 }
1139 
1140 Value BitCastRewriter::genericRewriteStep(
1141  PatternRewriter &rewriter, Location loc, Value initialValue,
1142  Value runningResult, const BitCastRewriter::Metadata &metadata) {
1143  // Create vector.shuffle from the metadata.
1144  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
1145  loc, initialValue, initialValue, metadata.shuffles);
1146 
1147  // Intersect with the mask.
1148  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1149  auto constOp = rewriter.create<arith::ConstantOp>(
1150  loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
1151  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
1152 
1153  // Align right on 0.
1154  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
1155  loc,
1156  DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
1157  Value shiftedRight =
1158  rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
1159 
1160  // Shift bits left into their final position.
1161  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
1162  loc,
1163  DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
1164  Value shiftedLeft =
1165  rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
1166 
1167  runningResult =
1168  runningResult
1169  ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
1170  : shiftedLeft;
1171 
1172  return runningResult;
1173 }
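// One step (illustrative), for the motivating vector<32xi5> -> vector<20xi8>
// bitcast fed by an arith.trunci from vector<32xi32>, so the shuffled element
// type is i32; %mask_cst, %shr_cst and %shl_cst stand for the constants built
// from `metadata` and are hypothetical names:
//
//   %s = vector.shuffle %in, %in [0, 1, 3, ...] : vector<32xi32>, vector<32xi32>
//   %a = arith.andi %s, %mask_cst : vector<20xi32>
//   %r = arith.shrui %a, %shr_cst : vector<20xi32>
//   %l = arith.shli %r, %shl_cst : vector<20xi32>
//   %acc = arith.ori %running_result, %l : vector<20xi32>
//
// The arith.ori is emitted only once a running result exists.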
1174 
1175 /// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
1176 /// bitwise ops that take advantage of high-level information to avoid leaving
1177 /// LLVM to scramble with peephole optimizations.
1178 static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
1179  Value srcValue) {
1180  VectorType srcVecType = cast<VectorType>(srcValue.getType());
1181  assert(srcVecType.getElementType().isSignlessInteger(4) &&
1182  "Expected i4 type");
1183 
1184  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1185  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
1186  constexpr int64_t i4Toi8BitwidthFactor = 2;
1187  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
1188  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
1189  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
1190 
1191  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
1192  // byte are placed in one vector and the high i4 elements in another vector.
1193  constexpr int8_t bitsToShift = 4;
1194  auto shiftValues = rewriter.create<arith::ConstantOp>(
1195  loc, DenseElementsAttr::get(i8VecType, bitsToShift));
1196  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
1197  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
1198  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
1199 
1200  // 3. Interleave low and high i8 elements.
1201  return rewriter.create<vector::InterleaveOp>(loc, low, high);
1202 }
1203 
1204 /// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
1205 /// bitwise ops that take advantage of high-level information to avoid leaving
1206 /// LLVM to scramble with peephole optimizations.
1207 static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
1208  Value srcValue) {
1209  VectorType srcVecType = cast<VectorType>(srcValue.getType());
1210  assert(srcVecType.getElementType().isSignlessInteger(4) &&
1211  "Expected i4 type");
1212 
1213  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1214  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
1215  constexpr int64_t i4Toi8BitwidthFactor = 2;
1216  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
1217  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
1218  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
1219 
1220  // 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
1221  // byte are placed in one vector and the high i4 elements in another vector.
1222  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
1223  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
1224  loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
1225  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
1226  lowBitsMaskValues);
1227  constexpr int8_t highBitsToShift = 4;
1228  auto highShiftValues = rewriter.create<arith::ConstantOp>(
1229  loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
1230  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
1231 
1232  // 3. Interleave low and high i8 elements.
1233  return rewriter.create<vector::InterleaveOp>(loc, low, high);
1234 }
1235 
1236 /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1237 /// ops that take advantage of high-level information to avoid leaving LLVM to
1238 /// scramble with peephole optimizations.
1239 static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
1240  Value srcValue) {
1241  VectorType srcVecType = cast<VectorType>(srcValue.getType());
1242  assert(srcVecType.getElementType().isSignlessInteger(8) &&
1243  "Expected i8 type");
1244 
1245  // 1. De-interleave low and high i8 elements.
1246  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);
1247 
1248  // 2. Zero out the upper side of each low i8 element.
1249  constexpr int8_t i8LowBitMask = 0x0F;
1250  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1251  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
1252  loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
1253  Value zeroOutLow = rewriter.create<arith::AndIOp>(
1254  loc, deinterleaveOp.getRes1(), zeroOutMask);
1255 
1256  // 3. Move high i4 values to upper side of the byte.
1257  constexpr int8_t bitsToShift = 4;
1258  auto shiftValues = rewriter.create<arith::ConstantOp>(
1259  loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
1260  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
1261  shiftValues);
1262 
1263  // 4. Merge high and low i4 values.
1264  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
1265 
1266  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
1267  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
1268  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
1269 }
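// For example (illustrative), for a vector<8xi8> input the steps above produce:
//
//   %low, %high = vector.deinterleave %in : vector<8xi8> -> vector<4xi8>
//   %lo_masked  = arith.andi %low, %c15_vec : vector<4xi8>
//   %hi_shifted = arith.shli %high, %c4_vec : vector<4xi8>
//   %merged     = arith.ori %lo_masked, %hi_shifted : vector<4xi8>
//   %res        = vector.bitcast %merged : vector<4xi8> to vector<8xi4>
//
// where %c15_vec and %c4_vec are splat constants (names are illustrative).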
1270 
1271 namespace {
1272 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
1273 /// advantage of high-level information to avoid leaving LLVM to scramble with
1274 /// peephole optimizations.
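/// For example (illustrative):
///
///   %t = arith.trunci %in : vector<32xi32> to vector<32xi5>
///   %b = vector.bitcast %t : vector<32xi5> to vector<20xi8>
///
/// is rewritten to assemble the i8 results directly from `%in` with
/// vector.shuffle, arith.andi, arith.shrui, arith.shli and arith.ori,
/// followed by a final arith.trunci to the result element type.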
1275 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
1276  using OpRewritePattern::OpRewritePattern;
1277 
1278  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1279  PatternRewriter &rewriter) const override {
1280  // The source must be a trunc op.
1281  auto truncOp =
1282  bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1283  if (!truncOp)
1284  return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
1285 
1286  // Set up the BitCastRewriter and verify the precondition.
1287  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1288  VectorType targetVectorType = bitCastOp.getResultVectorType();
1289  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1290  if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
1291  return failure();
1292 
1293  // Perform the rewrite.
1294  Value truncValue = truncOp.getIn();
1295  auto shuffledElementType =
1296  cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
1297  Value runningResult;
1298  for (const BitCastRewriter ::Metadata &metadata :
1299  bcr.precomputeMetadata(shuffledElementType)) {
1300  runningResult = bcr.genericRewriteStep(
1301  rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
1302  }
1303 
1304  // Finalize the rewrite.
1305  bool narrowing = targetVectorType.getElementTypeBitWidth() <=
1306  shuffledElementType.getIntOrFloatBitWidth();
1307  if (narrowing) {
1308  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1309  rewriter.replaceOp(bitCastOp, runningResult);
1310  } else {
1311  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1312  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1313  }
1314  } else {
1315  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1316  rewriter.replaceOp(bitCastOp, runningResult);
1317  } else {
1318  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1319  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1320  }
1321  }
1322 
1323  return success();
1324  }
1325 };
1326 } // namespace
1327 
1328 //===----------------------------------------------------------------------===//
1329 // RewriteExtOfBitCast
1330 //===----------------------------------------------------------------------===//
1331 
1332 namespace {
1333 /// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
1334 /// take advantage of high-level information to avoid leaving LLVM to scramble
1335 /// with peephole optimizations.
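/// For example (illustrative):
///
///   %b = vector.bitcast %in : vector<32xi5> to vector<20xi8>
///   %e = arith.extsi %b : vector<20xi8> to vector<20xi32>
///
/// is rewritten to extract the i8 values from `%in` with vector.shuffle and
/// bitwise ops, followed by a final arith.extsi (or arith.trunci, when the
/// result element type is narrower than the shuffled element type).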
1336 template <typename ExtOpType>
1337 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
1338  using OpRewritePattern<ExtOpType>::OpRewritePattern;
1339 
1340  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
1341  : OpRewritePattern<ExtOpType>(context, benefit) {}
1342 
1343  LogicalResult matchAndRewrite(ExtOpType extOp,
1344  PatternRewriter &rewriter) const override {
1345  // The source must be a bitcast op.
1346  auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
1347  if (!bitCastOp)
1348  return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
1349 
1350  // Set up the BitCastRewriter and verify the precondition.
1351  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1352  VectorType targetVectorType = bitCastOp.getResultVectorType();
1353  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1354  if (failed(bcr.commonPrecondition(
1355  rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
1356  return failure();
1357 
1358  // Perform the rewrite.
1359  Value runningResult;
1360  Value sourceValue = bitCastOp.getSource();
1361  auto shuffledElementType =
1362  cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
1363  for (const BitCastRewriter::Metadata &metadata :
1364  bcr.precomputeMetadata(shuffledElementType)) {
1365  runningResult = bcr.genericRewriteStep(
1366  rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
1367  }
1368 
1369  // Finalize the rewrite.
1370  bool narrowing =
1371  cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
1372  shuffledElementType.getIntOrFloatBitWidth();
1373  if (narrowing) {
1374  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1375  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1376  } else {
1377  rewriter.replaceOpWithNewOp<ExtOpType>(
1378  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1379  }
1380 
1381  return success();
1382  }
1383 };
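// Illustrative sketch of the IR shape this pattern matches (types are
// assumptions for the example only):
//
//   %0 = vector.bitcast %src : vector<4xi8> to vector<8xi4>
//   %1 = arith.extsi %0 : vector<8xi4> to vector<8xi32>
//
// The ext(bitcast) pair is replaced by shuffles and bitwise ops computed by
// BitCastRewriter on %src, finalized with a trunc or ext to the requested
// element width, as selected by the narrowing check above.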
1384 
1385 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
1386 /// bitwise ops that take advantage of high-level information to avoid leaving
1387 /// LLVM to scramble with peephole optimizations. Templated to choose between
1388 /// signed and unsigned conversions.
1389 ///
1390 /// For example (signed):
1391 /// arith.extsi %in : vector<8xi4> to vector<8xi32>
1392 /// is rewritten as
1393 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1394 /// %1 = arith.shli %0, 4 : vector<4xi8>
1395 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1396 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1397 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1398 /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
1399 ///
1400 /// arith.sitofp %in : vector<8xi4> to vector<8xf32>
1401 /// is rewritten as
1402 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1403 /// %1 = arith.shli %0, 4 : vector<4xi8>
1404 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1405 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1406 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1407 /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
1408 ///
1409 /// Example (unsigned):
1410 /// arith.extui %in : vector<8xi4> to vector<8xi32>
1411 /// is rewritten as
1412 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1413 /// %1 = arith.andi %0, 15 : vector<4xi8>
1414 /// %2 = arith.shrui %0, 4 : vector<4xi8>
1415 /// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
1416 /// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
1417 ///
1418 template <typename ConversionOpType, bool isSigned>
1419 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
1420  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
1421 
1422  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
1423  PatternRewriter &rewriter) const override {
1424  // Verify the preconditions.
1425  Value srcValue = conversionOp.getIn();
1426  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1427  auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1428 
1429  if (failed(
1430  commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
1431  return failure();
1432 
1433  // Check general alignment preconditions.
1434  if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1435  conversionOp)))
1436  return failure();
1437 
1438  // Perform the rewrite.
1439  Value subByteExt;
1440  if (isSigned) {
1441  subByteExt =
1442  rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1443  } else {
1444  subByteExt =
1445  rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1446  }
1447 
1448  // Finalize the rewrite.
1449  rewriter.replaceOpWithNewOp<ConversionOpType>(
1450  conversionOp, conversionOp.getType(), subByteExt);
1451  return success();
1452  }
1453 };
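// Worked byte-level example of the signed i4 -> i8 shift trick used by the
// aligned extension patterns above (the value 0xA7 is illustrative only):
// a byte 0xA7 packs the i4 values +7 (low nibble 0x7) and -6 (high nibble
// 0xA).
//
//   shli  0xA7, 4 = 0x70 ; shrsi 0x70, 4 = 0x07 (+7)  // sign-extended low nibble
//   shrsi 0xA7, 4 = 0xFA (-6)                         // sign-extended high nibble
//
// Interleaving the two results restores the original element order before
// the final ext/convert op.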
1454 
1455 /// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
1456 /// bitwise ops that take advantage of high-level information to avoid leaving
1457 /// LLVM to scramble with peephole optimizations.
1458 ///
1459 /// For example:
1460 /// arith.trunci %in : vector<8xi32> to vector<8xi4>
1461 /// is rewritten as
1462 ///
1463 /// %cst = arith.constant dense<15> : vector<4xi8>
1464 /// %cst_0 = arith.constant dense<4> : vector<4xi8>
1465 /// %0, %1 = vector.deinterleave %in : vector<8xi8> -> vector<4xi8>
1466 /// %2 = arith.andi %0, %cst : vector<4xi8>
1467 /// %3 = arith.shli %1, %cst_0 : vector<4xi8>
1468 /// %4 = arith.ori %2, %3 : vector<4xi8>
1469 /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
1470 ///
1471 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
1472  using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
1473 
1474  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
1475  PatternRewriter &rewriter) const override {
1476  // Verify the preconditions.
1477  Value srcValue = truncOp.getIn();
1478  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1479  auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
1480  if (!srcVecType || !dstVecType)
1481  return failure();
1482 
1483  if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
1484  return failure();
1485 
1486  // Check general alignment preconditions. We invert the src/dst type order
1487  // to reuse the existing precondition logic.
1488  if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
1489  truncOp)))
1490  return failure();
1491 
1492  // Create a new iX -> i8 truncation op.
1493  Location loc = truncOp.getLoc();
1494  auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
1495  Value i8TruncVal =
1496  rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
1497 
1498  // Rewrite the i8 -> i4 truncation part.
1499  Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
1500 
1501  // Finalize the rewrite.
1502  rewriter.replaceOp(truncOp, subByteTrunc);
1503  return success();
1504  }
1505 };
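// Worked byte-level example of the i8 -> i4 packing performed by
// rewriteI8ToI4Trunc (the concrete values are illustrative only): packing
// +7 (0x07) and -6 (0xFA) into one byte:
//
//   andi 0x07, 0x0F = 0x07   // keep the low nibble of the even element
//   shli 0xFA, 4    = 0xA0   // move the odd element into the high nibble
//   ori  0x07, 0xA0 = 0xA7   // packed byte; bitcast back to two i4 values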
1506 
1507 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
1508 /// perform the transpose on wider (byte) element types.
1509 /// For example:
1510 /// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
1511 ///
1512 /// is rewritten as:
1513 ///
1514 /// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
1515 /// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
1516 /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
1517 ///
1518 struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
1519  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
1520 
1521  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
1522  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
1523 
1524  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
1525  PatternRewriter &rewriter) const override {
1526  // Precondition: sub-byte integer transpose.
1527  constexpr unsigned minNativeBitwidth = 8;
1528  VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
1529  if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
1530  srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
1531  return rewriter.notifyMatchFailure(transposeOp,
1532  "not a sub-byte transpose");
1533  }
1534 
1535  // Perform the rewrite.
1536  Location loc = transposeOp.getLoc();
1537  // Signed/unsigned interpretation shouldn't matter here as we are just
1538  // transposing the elements and truncating them back to the original size.
1539  // TODO: Use unsigned extension (more efficient) when emulation or backend
1540  // support is available.
1541  auto srcNativeVecType = srcSubByteVecType.cloneWith(
1542  std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
1543  Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
1544  transposeOp.getVector());
1545  Value newTranspose = rewriter.create<vector::TransposeOp>(
1546  loc, extOp, transposeOp.getPermutation());
1547  VectorType dstSubByteVecType = transposeOp.getResultVectorType();
1548  rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
1549  newTranspose);
1550  return success();
1551  }
1552 };
1553 
1554 } // namespace
1555 
1556 //===----------------------------------------------------------------------===//
1557 // Public Interface Definition
1558 //===----------------------------------------------------------------------===//
1559 
1560 void vector::populateVectorNarrowTypeEmulationPatterns(
1561     const arith::NarrowTypeEmulationConverter &typeConverter,
1562     RewritePatternSet &patterns) {
1563 
1564  // Populate `vector.*` conversion patterns.
1565  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
1566  ConvertVectorMaskedStore, ConvertVectorTransferRead>(
1567  typeConverter, patterns.getContext());
1568 }
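// A minimal usage sketch (assumptions: `op` is the root operation being
// converted, the target bitwidth of 8 is illustrative, and the memref/arith
// emulation patterns plus a ConversionTarget are set up elsewhere):
//
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
//   // ... then run the dialect conversion driver, e.g.:
//   // (void)applyPartialConversion(op, target, std::move(patterns));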
1569 
1570 void vector::populateVectorNarrowTypeRewritePatterns(
1571     RewritePatternSet &patterns, PatternBenefit benefit) {
1572  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
1573  RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
1574  benefit);
1575 
1576  // Patterns for the aligned cases. These are given a higher benefit because
1577  // they are expected to generate better code when they apply.
1578  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
1579  RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
1580  RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
1581  benefit.getBenefit() + 1);
1582  patterns
1583  .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
1584  RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
1585  patterns.getContext(), benefit.getBenefit() + 1);
1586 }
1587 
1588 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
1589     RewritePatternSet &patterns, PatternBenefit benefit) {
1590  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
1591 }
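// A minimal usage sketch for the two rewrite-pattern entry points above
// (assumptions: `op` is the operation to transform and a greedy rewrite
// driver is acceptable; this is not part of this file):
//
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorNarrowTypeRewritePatterns(patterns);
//   vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));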