//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to emulate
// narrow types that are not supported by the target hardware, e.g. i4, using
// wider types, e.g. i8.
//
/// Currently, only power-of-two integer types are supported. These are
/// converted to wider integers that are either 8 bits wide or wider.
///
/// TODO: Support for non-powers-of-two.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <optional>

using namespace mlir;

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

/// Returns a compressed mask for the emulated vector. For example, when
/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
/// elements span two dest elements), this method compresses `vector<8xi1>`
/// into `vector<2xi1>`.
///
/// The compressed/output mask value is set iff any mask in the corresponding
/// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
/// `numSrcElemsPerDest` equals 2, and `numFrontPadElems` equals 1, the
/// following mask:
///
///   %mask = [1, 1, 0, 0, 0, 0]
///
/// will first be padded in the front with `numFrontPadElems` zeros, and zeros
/// will be added in the back to make the number of elements a multiple of
/// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
///
///   %mask = [0, 1, 1, 0, 0, 0, 0, 0]
///
/// then it will return the following new compressed mask:
///
///   %mask = [1, 1, 0, 0]
///
/// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
/// `numSrcElemsPerDest`.
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                  Location loc, Value mask,
                                                  int numSrcElems,
                                                  int numSrcElemsPerDest,
                                                  int numFrontPadElems = 0) {

  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

  auto numDestElems =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

  Operation *maskOp = mask.getDefiningOp();
  SmallVector<vector::ExtractOp, 2> extractOps;
  // TODO: add support for `vector.splat`.
  // Finding the mask creation operation.
  while (maskOp &&
         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
             maskOp)) {
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);
    }
  }

  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
          maskOp))
    return failure();

  // Computing the "compressed" mask. All the emulation logic (i.e. computing
  // new mask index) only happens on the last dimension of the vectors.
  SmallVector<int64_t> maskShape(
      cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
  maskShape.back() = numDestElems;
  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
  std::optional<Operation *> newMask =
      llvm::TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
          .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                OperandRange maskOperands = createMaskOp.getOperands();
                // The `vector.create_mask` op creates a mask arrangement
                // without any zeros at the front. Also, because
                // `numFrontPadElems` is strictly smaller than
                // `numSrcElemsPerDest`, the compressed mask generated by
                // padding the original mask by `numFrontPadElems` will not
                // have any zeros at the front as well.
                AffineExpr s0;
                bindSymbols(rewriter.getContext(), s0);
                s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
                    rewriter, loc, s0, origIndex);
                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
                newMaskOperands.push_back(
                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                             newMaskOperands);
              })
          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
                                            -> std::optional<Operation *> {
            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
            size_t numMaskOperands = maskDimSizes.size();
            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
                                                 numSrcElemsPerDest);

            // TODO: we only want the mask between [startIndex, maskIndex]
            // to be true, the rest are false.
            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
              return std::nullopt;

            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
            newMaskDimSizes.push_back(maskIndex);

            if (numFrontPadElems == 0)
              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                             newMaskDimSizes);

            SmallVector<bool> newMaskValues;
            for (int64_t i = 0; i < numDestElems; ++i)
              newMaskValues.push_back(i >= startIndex && i < maskIndex);
            auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
                                                      newMask);
          })
          .Case<arith::ConstantOp>([&](auto constantOp)
                                       -> std::optional<Operation *> {
            // TODO: Support multiple dimensions.
            if (maskShape.size() != 1)
              return std::nullopt;
            // Rearrange the original mask values to cover the whole potential
            // loading region. For example, in the case of using byte-size for
            // emulation, given the following mask:
            //
            //   %mask = [0, 1, 0, 1, 0, 0]
            //
            // With front offset of 1, the mask will be padded with 0s in the
            // front and back so that:
            // 1. It is aligned with the effective loading bits
            // 2. Its length is a multiple of `numSrcElemsPerDest` (and the
            //    total coverage size is a multiple of bytes). The new mask
            //    will look like this before compressing:
            //
            //   %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
            auto originalMask =
                cast<DenseIntElementsAttr>(constantOp.getValue());
            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);

            // Compressing by combining every `numSrcElemsPerDest` elements:
            SmallVector<bool> compressedMaskValues;
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j) {
                combinedValue |= paddedMaskValues[i + j];
              }
              compressedMaskValues.push_back(combinedValue);
            }
            return rewriter.create<arith::ConstantOp>(
                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
          });

  if (!newMask)
    return failure();

  while (!extractOps.empty()) {
    newMask = rewriter.create<vector::ExtractOp>(
        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
    extractOps.pop_back();
  }

  return *newMask;
}

/// Extracts a 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
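///
/// For example (illustrative values and SSA names, not part of the original
/// comment), extracting 3 elements starting at offset 2 from a `vector<8xi2>`
/// emits roughly:
///
///   %0 = vector.extract_strided_slice %src
///          {offsets = [2], sizes = [3], strides = [1]}
///        : vector<8xi2> to vector<3xi2>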
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
                                        VectorType extractType, Value source,
                                        int64_t frontOffset,
                                        int64_t subvecSize) {
  auto vectorType = cast<VectorType>(source.getType());
  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
         "expected 1-D source and destination types");
  (void)vectorType;
  assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
         "subvector out of bounds");

  // Do not need extraction if the subvector size is the same as the source.
  if (vectorType.getNumElements() == subvecSize)
    return source;

  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter
      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
                                             sizes, strides)
      ->getResult(0);
}

/// Inserts a 1-D subvector into a 1-D vector by overwriting the elements
/// starting at `offset`. It is a wrapper function for emitting
/// `vector.insert_strided_slice`.
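///
/// For example (illustrative values and SSA names), inserting a `vector<3xi2>`
/// `%src` into a `vector<8xi2>` `%dest` at offset 2 emits roughly:
///
///   %0 = vector.insert_strided_slice %src, %dest
///          {offsets = [2], strides = [1]}
///        : vector<3xi2> into vector<8xi2>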
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
         "expected source and dest to be vector type");
  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
                                                       dest, offsets, strides);
}

/// Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
/// and size `numElementsToExtract`, and inserts into the `dest` vector. This
/// function emits multiple `vector.extract` and `vector.insert` ops, so only
/// use it when `offset` cannot be folded into a constant value.
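///
/// For example (illustrative, assumed SSA names), extracting 2 elements from
/// `%src : vector<8xi2>` at dynamic offset `%off` into `%dest : vector<2xi2>`
/// emits roughly:
///
///   %e0 = vector.extract %src[%off] : i2 from vector<8xi2>
///   %d0 = vector.insert %e0, %dest [0] : i2 into vector<2xi2>
///   %c1 = arith.constant 1 : index
///   %i1 = arith.addi %off, %c1 : index
///   %e1 = vector.extract %src[%i1] : i2 from vector<8xi2>
///   %d1 = vector.insert %e1, %d0 [1] : i2 into vector<2xi2>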
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
                                         TypedValue<VectorType> source,
                                         Value dest, OpFoldResult offset,
                                         int64_t numElementsToExtract) {
  for (int i = 0; i < numElementsToExtract; ++i) {
    Value extractLoc =
        (i == 0) ? offset.dyn_cast<Value>()
                 : rewriter.create<arith::AddIOp>(
                       loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
                       rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp =
        rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
  }
  return dest;
}

/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
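///
/// For example (illustrative, assumed SSA names), inserting the 2 elements of
/// `%src : vector<2xi2>` into `%dest : vector<8xi2>` at dynamic offset `%off`
/// emits roughly:
///
///   %e0 = vector.extract %src[0] : i2 from vector<2xi2>
///   %d0 = vector.insert %e0, %dest [%off] : i2 into vector<8xi2>
///   %c1 = arith.constant 1 : index
///   %i1 = arith.addi %off, %c1 : index
///   %e1 = vector.extract %src[1] : i2 from vector<2xi2>
///   %d1 = vector.insert %e1, %d0 [%i1] : i2 into vector<8xi2>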
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
                                        TypedValue<VectorType> source,
                                        Value dest, OpFoldResult destOffsetVar,
                                        size_t length) {
  assert(length > 0 && "length must be greater than 0");
  Value destOffsetVal =
      getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
  for (size_t i = 0; i < length; ++i) {
    auto insertLoc = i == 0
                         ? destOffsetVal
                         : rewriter.create<arith::AddIOp>(
                               loc, rewriter.getIndexType(), destOffsetVal,
                               rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
  }
  return dest;
}

/// Returns the op sequence for an emulated sub-byte data type vector load.
/// Specifically, use `emulatedElemType` for loading a vector of `origElemType`.
/// The load location is given by `base` and `linearizedIndices`, and the
/// load size is given by `numEmulatedElementsToLoad`.
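///
/// For example (illustrative, assumed SSA names), loading 2 emulated `i8`
/// elements that hold 4 original `i4` elements emits roughly:
///
///   %0 = vector.load %base[%linearizedIndices] : memref<?xi8>, vector<2xi8>
///   %1 = vector.bitcast %0 : vector<2xi8> to vector<4xi4>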
static TypedValue<VectorType>
emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                   OpFoldResult linearizedIndices,
                   int64_t numEmulatedElementsToLoad, Type origElemType,
                   Type emulatedElemType) {
  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
               origElemType.getIntOrFloatBitWidth();
  auto newLoad = rewriter.create<vector::LoadOp>(
      loc, VectorType::get(numEmulatedElementsToLoad, emulatedElemType), base,
      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
  return rewriter.create<vector::BitCastOp>(
      loc, VectorType::get(numEmulatedElementsToLoad * scale, origElemType),
      newLoad);
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertVectorStore
//===----------------------------------------------------------------------===//

struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getValueToStore().getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // Adjust the number of elements to store when emulating narrow types.
    // Here only the 1-D vector store is considered, and the N-D memref types
    // should be linearized.
    // For example, to emulate i4 to i8, the following op:
    //
    //   vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
    //
    // can be replaced with
    //
    //   %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
    //   vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
    //     vector<4xi8>

    auto origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    std::tie(std::ignore, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    auto numElements = origElements / scale;
    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements, newElementType),
        op.getValueToStore());

    rewriter.replaceOpWithNewOp<vector::StoreOp>(
        op, bitCast.getResult(), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorMaskedStore
//===----------------------------------------------------------------------===//

struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getValueToStore().getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    int scale = dstBits / srcBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndicesOfr) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
    Value linearizedIndices =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);

    // Load the whole data and use arith.select to handle the corner cases.
    //
    // As an example, for this masked store of i4 values:
    //
    //   vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
    //
    // and given these input values:
    //
    //   %mask = [0, 1, 1, 1, 1, 0, 0, 0]                      (8 * i1)
    //   %0[%c0, %c0] =
    //     [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]            (8 * i4)
    //   %val_to_store =
    //     [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]            (8 * i4)
    //
    // we'll have the following i4 output:
    //
    //   expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
    //
    // Emulating the above using i8 will give:
    //
    //   %compressed_mask = [1, 1, 1, 0]                       (4 * i1)
    //   %maskedload = [0x12, 0x34, 0x56, 0x00]                (4 * i8)
    //   %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]   (8 * i4)
    //   %select_using_shifted_mask =
    //     [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0]            (8 * i4)
    //   %packed_data = [0x1A, 0xBC, 0xD6, 0x00]               (4 * i8)
    //
    // Using the compressed mask to store %packed_data results in expected
    // output.
    //
    // FIXME: Make an example based on the comment above work (see #115460 for
    // reproducer).
    FailureOr<Operation *> newMask =
        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
    if (failed(newMask))
      return failure();

    auto numElements = (origElements + scale - 1) / scale;
    auto newType = VectorType::get(numElements, newElementType);
    auto passThru = rewriter.create<arith::ConstantOp>(
        loc, newType, rewriter.getZeroAttr(newType));

    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
    Value valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
    valueToStore = rewriter.create<arith::SelectOp>(
        loc, op.getMask(), op.getValueToStore(), valueToStore);
    valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
        valueToStore);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorLoad
//===----------------------------------------------------------------------===//

struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // Adjust the number of elements to load when emulating narrow types,
    // and then cast back to the original type with vector.bitcast op.
    // Here only the 1-D vector load is considered, and the N-D memref types
    // should be linearized.
    // For example, to emulate i4 to i8, the following op:
    //
    //   %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
    //
    // can be replaced with
    //
    //   %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
    //   %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
    //
    // There are cases where the number of elements to load is not byte-aligned,
    // for example:
    //
    //   %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
    //
    // we will have to load extra bytes and extract the exact slice in between.
    //
    //   %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
    //   %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
    //   %3 = vector.extract_strided_slice %2 {offsets = [2], sizes = [3],
    //        strides = [1]} : vector<8xi2> to vector<3xi2>
    //
    // TODO: Currently the extract_strided_slice's attributes must be known at
    // compile time as they must be constants.

    auto origElements = op.getVectorType().getNumElements();
    bool isUnalignedEmulation = origElements % scale != 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isUnalignedEmulation
            ? getConstantIntValue(linearizedInfo.intraDataOffset)
            : 0;

    // Always load enough elements to cover the original elements.
    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
    Value result =
        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                           numElements, oldElementType, newElementType);

    if (!foldedIntraVectorOffset) {
      auto resultVector = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
          linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
      result =
          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                     *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//

struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // Adjust the number of elements to load when emulating narrow types,
    // and then cast back to the original type with vector.bitcast op.
    // For example, to emulate i4 to i8, the following op:
    //
    //   %mask = vector.constant_mask [3] : vector<6xi1>
    //   %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
    //        memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
    //
    // can be replaced with
    //
    //   %new_mask = vector.constant_mask [2] : vector<3xi1>
    //   %new_pass_thru = vector.bitcast %pass_thru :
    //        vector<6xi4> to vector<3xi8>
    //   %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
    //        memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
    //   %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
    //
    // Since we are effectively loading 16 bits (2xi8) from the memref with the
    // new mask, while originally we only wanted to effectively load 12 bits
    // (3xi4) from the memref, we need to set the second half of the last i8
    // that was effectively loaded (i.e. the second i8) to %pass_thru.
    //
    //   %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
    //
    // Given these input values:
    //   %mask = [1, 1, 1, 0, 0, 0]
    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
    //   %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
    //
    // we'll have:
    //
    //   expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
    //
    //   %new_mask = [1, 1, 0]
    //   %new_pass_thru = [0x78, 0x9A, 0xBC]
    //   %1 = [0x12, 0x34, 0xBC]
    //   %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
    //   %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
    //
    // TODO: Currently, only loading an even number of elements is supported.
    // To deal with an odd number of elements, one has to extract the
    // subvector at the proper offset after bit-casting.
    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    bool isUnalignedEmulation = origElements % scale != 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isUnalignedEmulation
            ? getConstantIntValue(linearizedInfo.intraDataOffset)
            : 0;

    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    FailureOr<Operation *> newMask = getCompressedMaskOp(
        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
    if (failed(newMask))
      return failure();

    Value passthru = op.getPassThru();

    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
    auto loadType = VectorType::get(numElements, newElementType);
    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);

    auto emptyVector = rewriter.create<arith::ConstantOp>(
        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
      passthru = dynamicallyInsertSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
          emptyVector, linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                           *foldedIntraVectorOffset);
    }
    auto newPassThru =
        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);

    // Generating the new masked load.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, loadType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newMask.value()->getResult(0), newPassThru);

    // Setting the part that originally was not effectively loaded from memory
    // to pass through.
    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType =
        VectorType::get(numElements * scale, rewriter.getI1Type());
    // TODO: try to fold if op's mask is constant
    auto emptyMask = rewriter.create<arith::ConstantOp>(
        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
      mask = dynamicallyInsertSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
          linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                       *foldedIntraVectorOffset);
    }

    Value result =
        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
      result = dynamicallyExtractSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
          op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
    } else if (isUnalignedEmulation) {
      result =
          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                     *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//

struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

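    // For example (illustrative, mirroring the vector.load comment above), an
    // aligned i4 -> i8 emulation rewrites:
    //
    //   %1 = vector.transfer_read %0[%c0], %pad : memref<8xi4>, vector<4xi4>
    //
    // into roughly:
    //
    //   %pad_i8 = arith.extui %pad : i4 to i8
    //   %1 = vector.transfer_read %0[%linear_index], %pad_i8 :
    //        memref<4xi8>, vector<2xi8>
    //   %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>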
    auto origElements = op.getVectorType().getNumElements();

    bool isUnalignedEmulation = origElements % scale != 0;

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isUnalignedEmulation
            ? getConstantIntValue(linearizedInfo.intraDataOffset)
            : 0;

    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements * scale, oldElementType), newRead);

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (isUnalignedEmulation) {
      result =
          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                     *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// RewriteBitCastOfTruncI
//===----------------------------------------------------------------------===//

namespace {

/// Helper struct to keep track of the provenance of a contiguous set of bits
/// in a source vector.
struct SourceElementRange {
  /// The index of the source vector element that contributes bits to *this.
  int64_t sourceElementIdx;
  /// The range of bits in the source vector element that contribute to *this.
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// Given the index of a SourceElementRange in the SourceElementRangeList,
  /// compute the amount of bits that need to be shifted to the left to get the
  /// bits in their final location. This shift amount is simply the sum of the
  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
  /// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};

/// Helper struct to enumerate the source elements and bit ranges that are
/// involved in a bitcast operation.
/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
/// any 1-D vector shape and any source/target bitwidths.
/// This creates and holds a mapping of the form:
///   [dstVectorElementJ] ==
///     [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
///   [0] = {0, [0-8)}
///   [1] = {0, [8-16)}
///   [2] = {0, [16-24)}
/// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
///   [0] = {0, [0, 10)}
///   [1] = {0, [10, 15)}, {1, [0, 5)}
///   [2] = {1, [5, 15)}
struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};

/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
/// BitCastBitsEnumerator encodes for each element of the target vector the
/// provenance of the bits in the source vector. We can "transpose" this
/// information to build a sequence of shuffles and bitwise ops that will
/// produce the desired result.
//
/// Consider the following motivating example:
/// ```
///   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
/// ```
//
/// BitCastBitsEnumerator contains the following information:
/// ```
/// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
/// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
/// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
/// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
/// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
/// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
/// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
/// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
/// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
/// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
/// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
/// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
/// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
/// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
/// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
/// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
/// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
/// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
/// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
/// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
/// ```
///
/// In the above, each row represents one target vector element and each
/// column represents one bit contribution from a source vector element.
/// The algorithm creates vector.shuffle operations (in this case there are 3
/// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
/// algorithm populates the bits as follows:
/// ```
///         src bits 0 ...
/// 1st shuffle |xxxxx   |xx      |...
/// 2nd shuffle |     xxx|  xxxxx |...
/// 3rd shuffle |        |       x|...
/// ```
//
/// The algorithm proceeds as follows:
///   1. for each vector.shuffle, collect the source vectors that participate in
///      this shuffle. One source vector per target element of the resulting
///      vector.shuffle. If there is no source element contributing bits for the
///      current vector.shuffle, take 0 (i.e. row 0 in the above example has only
///      2 columns).
///   2. represent the bitrange in the source vector as a mask. If there is no
///      source element contributing bits for the current vector.shuffle, take 0.
///   3. shift right by the proper amount to align the source bitrange at
///      position 0. This is exactly the low end of the bitrange. For instance,
///      the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
///      shift right by 3 to get the bits contributed by the source element #1
///      into position 0.
///   4. shift left by the proper amount to align to the desired position in
///      the result element vector. For instance, the contribution of the second
///      source element for the first row needs to be shifted by `5` to form the
///      first i8 result element.
///
/// Eventually, we end up building the sequence
/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
struct BitCastRewriter {
  /// Helper metadata struct to hold the static quantities for the rewrite.
  struct Metadata {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
  };

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  /// Verify that general preconditions for the rewrite are met.
  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  /// Precompute the metadata for the rewrite.
  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  /// Rewrite one step of the sequence:
  ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

private:
  /// Underlying enumerator that encodes the provenance of the bits in each
  /// element of the result vector.
  BitCastBitsEnumerator enumerator;
};

} // namespace

[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
    }
    os << "\n";
  }
  return os;
}

BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG("sourceVectorType: " << sourceVectorType);

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG("targetVectorType: " << targetVectorType);

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  // Prepopulate one source element range per target element.
  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
}

BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG("\n" << enumerator.sourceElementRanges);
}

/// Verify that the precondition type meets the common preconditions for any
/// conversion.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!preconditionType || preconditionType.isScalable())
    return rewriter.notifyMatchFailure(op, "scalable vector");

  // TODO: consider relaxing this restriction in the future if we find ways
  // to really work with subbyte elements across the MLIR/LLVM boundary.
  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");

  return success();
}

LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
    return rewriter.notifyMatchFailure(op, "types are not vector");

  if (!preconditionType || preconditionType.getRank() != 1)
    return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");

  return commonConversionPrecondition(rewriter, preconditionType, op);
}

/// Verify that source and destination element types meet the precondition for
/// the supported aligned conversion cases. Alignment means that either the
/// source element bitwidth is a multiple of the destination element bitwidth
/// or the other way around.
///
/// NOTE: This method assumes that common conversion preconditions are met.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType srcType,
                                                   VectorType dstType,
                                                   Operation *op) {
  if (!srcType || !dstType)
    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();

  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
      (dstElemBitwidth % srcElemBitwidth) != 0)
    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");

  if ((srcType.getShape().back() % 2) != 0)
    return rewriter.notifyMatchFailure(
        op, "Not an even number of i4 elements in trailing dim");

  return success();
}

SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

    // Create the attribute quantities for the shuffle / mask / shift ops.
    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
                                  : 0;
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
                          : 0;
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
                          : 0;
      IntegerAttr mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
                                  bitLo, bitHi));
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftRight));

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftLeft));
    }

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
  }
  return result;
}

Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  // Create vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Intersect with the mask.
  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);

  // Align right on 0.
  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
  Value shiftedRight =
      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

  // Shift bits left into their final position.
  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
  Value shiftedLeft =
      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

  runningResult =
      runningResult
          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
          : shiftedLeft;

  return runningResult;
}

/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
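///
/// A sketch of the emitted sequence (illustrative; the full examples appear in
/// the aligned sub-byte extension pattern documentation below):
///
///   %0 = vector.bitcast %src : vector<8xi4> to vector<4xi8>
///   %1 = arith.shli %0, 4 : vector<4xi8>
///   %2 = arith.shrsi %1, 4 : vector<4xi8>   // low i4s, sign-extended
///   %3 = arith.shrsi %0, 4 : vector<4xi8>   // high i4s, sign-extended
///   %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>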
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
                                    Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
  constexpr int64_t i4Toi8BitwidthFactor = 2;
  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);

  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
  // byte are placed in one vector and the high i4 elements in another vector.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);

  // 3. Interleave low and high i8 elements.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}

/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
                                      Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
  constexpr int64_t i4Toi8BitwidthFactor = 2;
  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);

  // 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
  // byte are placed in one vector and the high i4 elements in another vector.
  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
                                             lowBitsMaskValues);
  constexpr int8_t highBitsToShift = 4;
  auto highShiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);

  // 3. Interleave low and high i8 elements.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}

/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
/// ops that take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
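///
/// A sketch of the emitted sequence (illustrative, assumed SSA names):
///
///   %low, %high = vector.deinterleave %src : vector<8xi8> -> vector<4xi8>
///   %masked_low = arith.andi %low, 15 : vector<4xi8>
///   %shl_high = arith.shli %high, 4 : vector<4xi8>
///   %merged = arith.ori %masked_low, %shl_high : vector<4xi8>
///   %res = vector.bitcast %merged : vector<4xi8> to vector<8xi4>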
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  // 1. De-interleave low and high i8 elements.
  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);

  // 2. Zero out the upper side of each low i8 element.
  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
  Value zeroOutLow = rewriter.create<arith::AndIOp>(
      loc, deinterleaveOp.getRes1(), zeroOutMask);

  // 3. Move high i4 values to upper side of the byte.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
                                                 shiftValues);

  // 4. Merge high and low i4 values.
  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);

  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
}

namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a trunc op.
    auto truncOp =
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
    if (!truncOp)
      return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::TruncIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    } else {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    }

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// RewriteExtOfBitCast
//===----------------------------------------------------------------------===//

namespace {
/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
/// take advantage of high-level information to avoid leaving LLVM to scramble
/// with peephole optimizations.
1354 template <typename ExtOpType>
1355 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
1357 
1358  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
1359  : OpRewritePattern<ExtOpType>(context, benefit) {}
1360 
1361  LogicalResult matchAndRewrite(ExtOpType extOp,
1362  PatternRewriter &rewriter) const override {
1363  // The source must be a bitcast op.
1364  auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
1365  if (!bitCastOp)
1366  return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
1367 
1368  // Set up the BitCastRewriter and verify the precondition.
1369  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1370  VectorType targetVectorType = bitCastOp.getResultVectorType();
1371  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1372  if (failed(bcr.commonPrecondition(
1373  rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
1374  return failure();
1375 
1376  // Perform the rewrite.
1377  Value runningResult;
1378  Value sourceValue = bitCastOp.getSource();
1379  auto shuffledElementType =
1380  cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
1381  for (const BitCastRewriter::Metadata &metadata :
1382  bcr.precomputeMetadata(shuffledElementType)) {
1383  runningResult = bcr.genericRewriteStep(
1384  rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
1385  }
1386 
1387  // Finalize the rewrite.
1388  bool narrowing =
1389  cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
1390  shuffledElementType.getIntOrFloatBitWidth();
1391  if (narrowing) {
1392  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1393  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1394  } else {
1395  rewriter.replaceOpWithNewOp<ExtOpType>(
1396  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
1397  }
1398 
1399  return success();
1400  }
1401 };
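// Editor-added note (illustrative only, not from the original file): this
// pattern fires on IR where a bitcast feeds an extension, e.g.
//
//   %cast = vector.bitcast %in : vector<4xi8> to vector<8xi4>
//   %ext  = arith.extsi %cast : vector<8xi4> to vector<8xi32>
//
// and rebuilds the result from shuffles and bitwise ops on the wider (here i8)
// source elements, finishing with an arith.trunci or the original extension op
// depending on the result element width.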
1402 
1403 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
1404 /// bitwise ops that take advantage of high-level information to avoid leaving
1405 /// LLVM to scramble with peephole optimizations. Templated to choose between
1406 /// signed and unsigned conversions.
1407 ///
1408 /// For example (signed):
1409 /// arith.extsi %in : vector<8xi4> to vector<8xi32>
1410 /// is rewritten as
1411 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1412 /// %1 = arith.shli %0, 4 : vector<4xi8>
1413 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1414 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1415 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1416 /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
1417 ///
1418 /// arith.sitofp %in : vector<8xi4> to vector<8xf32>
1419 /// is rewritten as
1420 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1421 /// %1 = arith.shli %0, 4 : vector<4xi8>
1422 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
1423 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
1424 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
1425 /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
1426 ///
1427 /// Example (unsigned):
1428 /// arith.extui %in : vector<8xi4> to vector<8xi32>
1429 /// is rewritten as
1430 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
1431 /// %1 = arith.andi %0, 15 : vector<4xi8>
1432 /// %2 = arith.shrui %0, 4 : vector<4xi8>
1433 /// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
1434 /// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
1435 ///
1436 template <typename ConversionOpType, bool isSigned>
1437 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
1438  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
1439 
1440  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
1441  PatternRewriter &rewriter) const override {
1442  // Verify the preconditions.
1443  Value srcValue = conversionOp.getIn();
1444  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1445  auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
1446 
1447  if (failed(
1448  commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
1449  return failure();
1450 
1451  // Check general alignment preconditions.
1452  if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1453  conversionOp)))
1454  return failure();
1455 
1456  // Perform the rewrite.
1457  Value subByteExt;
1458  if (isSigned) {
1459  subByteExt =
1460  rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1461  } else {
1462  subByteExt =
1463  rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1464  }
1465 
1466  // Finalize the rewrite.
1467  rewriter.replaceOpWithNewOp<ConversionOpType>(
1468  conversionOp, conversionOp.getType(), subByteExt);
1469  return success();
1470  }
1471 };
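// --- Editor-added sketch; not part of the original file ---------------------
// Standalone byte-level illustrations of the aligned i4 -> i8 tricks used
// above (assumes only <cstdint>; the helper names are hypothetical). The
// results rely on the usual two's-complement wrap-around and arithmetic right
// shift, which is what the emitted shli/shrsi (signed) and andi/shrui
// (unsigned) sequences compute.
constexpr int8_t signExtendLowNibble(uint8_t byte) {
  // shli 4 then shrsi 4 sign-extends the low nibble.
  return static_cast<int8_t>(static_cast<int8_t>(byte << 4) >> 4);
}
constexpr int8_t signExtendHighNibble(uint8_t byte) {
  // shrsi 4 sign-extends the high nibble.
  return static_cast<int8_t>(static_cast<int8_t>(byte) >> 4);
}
constexpr uint8_t zeroExtendLowNibble(uint8_t byte) { return byte & 0x0F; } // andi 15
constexpr uint8_t zeroExtendHighNibble(uint8_t byte) { return byte >> 4; }  // shrui 4

// 0xF9 holds the i4 payloads 0x9 (signed -7, low) and 0xF (signed -1, high).
static_assert(signExtendLowNibble(0xF9) == -7, "signed low nibble");
static_assert(signExtendHighNibble(0xF9) == -1, "signed high nibble");
static_assert(zeroExtendLowNibble(0xF9) == 0x09, "unsigned low nibble");
static_assert(zeroExtendHighNibble(0xF9) == 0x0F, "unsigned high nibble");
// -----------------------------------------------------------------------------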
1472 
1473 /// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
1474 /// bitwise ops that take advantage of high-level information to avoid leaving
1475 /// LLVM to scramble with peephole optimizations.
1476 ///
1477 /// For example:
1478 /// arith.trunci %in : vector<8xi32> to vector<8xi4>
1479 /// is rewritten as (where %tr = arith.trunci %in : vector<8xi32> to vector<8xi8>):
1480 ///
1481 /// %cst = arith.constant dense<15> : vector<4xi8>
1482 /// %cst_0 = arith.constant dense<4> : vector<4xi8>
1483 /// %0, %1 = vector.deinterleave %tr : vector<8xi8> -> vector<4xi8>
1484 /// %2 = arith.andi %0, %cst : vector<4xi8>
1485 /// %3 = arith.shli %1, %cst_0 : vector<4xi8>
1486 /// %4 = arith.ori %2, %3 : vector<4xi8>
1487 /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
1488 ///
1489 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
1490  using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
1491 
1492  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
1493  PatternRewriter &rewriter) const override {
1494  // Verify the preconditions.
1495  Value srcValue = truncOp.getIn();
1496  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
1497  auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
1498  if (!srcVecType || !dstVecType)
1499  return failure();
1500 
1501  if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
1502  return failure();
1503 
1504  // Check general alignment preconditions. We invert the src/dst type order
1505  // to reuse the existing precondition logic.
1506  if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
1507  truncOp)))
1508  return failure();
1509 
1510  // Create a new iX -> i8 truncation op.
1511  Location loc = truncOp.getLoc();
1512  auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
1513  Value i8TruncVal =
1514  rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
1515 
1516  // Rewrite the i8 -> i4 truncation part.
1517  Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
1518 
1519  // Finalize the rewrite.
1520  rewriter.replaceOp(truncOp, subByteTrunc);
1521  return success();
1522  }
1523 };
1524 
1525 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
1526 /// perform the transpose on wider (byte) element types.
1527 /// For example:
1528 /// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
1529 ///
1530 /// is rewritten as:
1531 ///
1532 /// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
1533 /// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
1534 /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
1535 ///
1536 struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
1537  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
1538 
1539  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
1540  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
1541 
1542  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
1543  PatternRewriter &rewriter) const override {
1544  // Precondition: sub-byte integer transpose.
1545  constexpr unsigned minNativeBitwidth = 8;
1546  VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
1547  if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
1548  srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
1549  return rewriter.notifyMatchFailure(transposeOp,
1550  "not a sub-byte transpose");
1551  }
1552 
1553  // Perform the rewrite.
1554  Location loc = transposeOp.getLoc();
1555  // Signed/unsigned interpretation shouldn't matter here as we are just
1556  // transposing the elements and truncating them back to the original size.
1557  // TODO: Use unsigned extension (more efficient) when emulation or backend
1558  // support is available.
1559  auto srcNativeVecType = srcSubByteVecType.cloneWith(
1560  std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
1561  Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
1562  transposeOp.getVector());
1563  Value newTranspose = rewriter.create<vector::TransposeOp>(
1564  loc, extOp, transposeOp.getPermutation());
1565  VectorType dstSubByteVecType = transposeOp.getResultVectorType();
1566  rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
1567  newTranspose);
1568  return success();
1569  }
1570 };
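// --- Editor-added sketch; not part of the original file ---------------------
// Why the signed extension above is sound even for "unsigned" data: the
// widened bits only exist while the elements are transposed, and the final
// arith.trunci discards them again. Hypothetical, value-level helpers:
constexpr int signedI4(unsigned nibble) {
  // Interpret a 4-bit payload as a signed i4 value.
  return static_cast<int>(nibble & 0xF) - ((nibble & 0x8) ? 16 : 0);
}
constexpr unsigned truncToI4(int value) {
  // Keep only the low 4 bits, as arith.trunci to i4 would.
  return static_cast<unsigned>(value) & 0xF;
}
static_assert(truncToI4(signedI4(0x9)) == 0x9, "extsi + trunci round-trips");
static_assert(truncToI4(0x9) == 0x9, "extui + trunci round-trips");
// -----------------------------------------------------------------------------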
1571 
1572 } // namespace
1573 
1574 //===----------------------------------------------------------------------===//
1575 // Public Interface Definition
1576 //===----------------------------------------------------------------------===//
1577 
1578 void vector::populateVectorNarrowTypeEmulationPatterns(
1579  const arith::NarrowTypeEmulationConverter &typeConverter,
1580  RewritePatternSet &patterns) {
1581 
1582  // Populate `vector.*` conversion patterns.
1583  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
1584  ConvertVectorMaskedStore, ConvertVectorTransferRead>(
1585  typeConverter, patterns.getContext());
1586 }
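// Editor-added usage sketch (hedged; not from the original file). These
// patterns are intended to run inside a dialect conversion driven by an
// arith::NarrowTypeEmulationConverter, roughly along these lines (`ctx` is an
// assumed MLIRContext *, the target bitwidth of 8 is only an example, and the
// legality/target setup is omitted):
//
//   arith::NarrowTypeEmulationConverter converter(/*targetBitwidth=*/8);
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeEmulationPatterns(converter, patterns);
//   // ... also add the matching arith/memref emulation patterns, build a
//   // ConversionTarget, then run applyPartialConversion.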
1587 
1588 void vector::populateVectorNarrowTypeRewritePatterns(
1589  RewritePatternSet &patterns, PatternBenefit benefit) {
1590  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
1591  RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
1592  benefit);
1593 
1594  // Patterns for aligned cases. We register these with a higher benefit as
1595  // they are expected to generate better code for aligned cases.
1596  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
1597  RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
1598  RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
1599  benefit.getBenefit() + 1);
1600  patterns
1601  .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
1602  RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
1603  patterns.getContext(), benefit.getBenefit() + 1);
1604 }
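// Editor-added usage sketch (hedged; not from the original file). These are
// ordinary rewrite patterns, typically applied with the greedy driver, e.g.
// from within a pass (`ctx` and `op` are assumed to be the pass's context and
// root operation):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeRewritePatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
//
// Because the aligned sub-byte patterns are registered with `benefit + 1`,
// they take precedence over the generic shuffle-based rewrites whenever both
// match.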
1605 
1606 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
1607  RewritePatternSet &patterns, PatternBenefit benefit) {
1608  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
1609 }