VectorEmulateNarrowType.cpp
1 //===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to emulate
10 // narrow types that are not supported by the target hardware, e.g. i4
11 // ("emulated type"), using wider types, e.g. i8 ("container type").
12 //
13 /// Currently, only power-of-two integer types are supported. These are
14 /// converted to wider integers that are 8 bits or wider.
15 ///
16 /// TODO: Support for non-powers-of-two.
17 //===----------------------------------------------------------------------===//
18 
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/OpDefinition.h"
31 #include "mlir/IR/TypeUtilities.h"
32 #include "mlir/IR/Value.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/DebugLog.h"
36 #include "llvm/Support/MathExtras.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <cstdint>
39 #include <optional>
40 
42 
43 using namespace mlir;
44 
45 #define DEBUG_TYPE "vector-narrow-type-emulation"
46 
47 using VectorValue = TypedValue<VectorType>;
48 using MemRefValue = TypedValue<MemRefType>;
49 
50 //===----------------------------------------------------------------------===//
51 // Utils
52 //===----------------------------------------------------------------------===//
53 
54 /// Returns a compressed mask for the emulated vector. For example, when
55 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
56 /// elements span two dest elements), this method compresses `vector<8xi1>`
57 /// into `vector<2xi1>`.
58 ///
59 /// The compressed/output mask value is set iff any mask in the corresponding
60 /// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
61 /// `numSrcElemsPerDest` equals 2, and `numFrontPadElems` equals 1, the
62 /// following mask:
63 ///
64 /// %mask = [1, 1, 0, 0, 0, 0]
65 ///
66 /// will first be padded in the front with `numFrontPadElems` zeros, and zeros
67 /// will be added in the back to make the number of elements a multiple of
68 /// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
69 ///
70 /// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
71 ///
72 /// then it will return the following new compressed mask:
73 ///
74 /// %mask = [1, 1, 0, 0]
75 ///
76 /// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
77 /// `numSrcElemsPerDest`.
78 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
79  Location loc, Value mask,
80  int numSrcElems,
81  int numSrcElemsPerDest,
82  int numFrontPadElems = 0) {
83 
84  assert(numFrontPadElems < numSrcElemsPerDest &&
85  "numFrontPadElems must be less than numSrcElemsPerDest");
86 
87  auto numDestElems =
88  (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
89  numSrcElemsPerDest;
90 
91  Operation *maskOp = mask.getDefiningOp();
92  SmallVector<vector::ExtractOp, 2> extractOps;
93  // TODO: add support for `vector.splat`.
94  // Finding the mask creation operation.
95  while (maskOp &&
96  !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
97  maskOp)) {
98  if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
99  maskOp = extractOp.getSource().getDefiningOp();
100  extractOps.push_back(extractOp);
101  }
102  }
103 
104  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
105  maskOp))
106  return failure();
107 
108  // Computing the "compressed" mask. All the emulation logic (i.e. computing
109  // new mask index) only happens on the last dimension of the vectors.
110  SmallVector<int64_t> maskShape(
111  cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
112  maskShape.back() = numDestElems;
113  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
114  std::optional<Operation *> newMask =
115  TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
116  .Case<vector::CreateMaskOp>(
117  [&](auto createMaskOp) -> std::optional<Operation *> {
118  OperandRange maskOperands = createMaskOp.getOperands();
119  // The `vector.create_mask` op creates a mask arrangement
120  // without any zeros at the front. Also, because
121  // `numFrontPadElems` is strictly smaller than
122  // `numSrcElemsPerDest`, the compressed mask generated by
123  // padding the original mask by `numFrontPadElems` will not
124  // have any zeros at the front either.
125  AffineExpr s0;
126  bindSymbols(rewriter.getContext(), s0);
127  s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
128  OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
129  OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
130  rewriter, loc, s0, origIndex);
131  SmallVector<Value> newMaskOperands(maskOperands.drop_back());
132  newMaskOperands.push_back(
133  getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
134  return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
135  newMaskOperands);
136  })
137  .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
138  -> std::optional<Operation *> {
139  // Take the shape of mask, compress its trailing dimension:
140  SmallVector<int64_t> maskDimSizes(constantMaskOp.getMaskDimSizes());
141  int64_t &maskIndex = maskDimSizes.back();
142  maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
143  numSrcElemsPerDest);
144  return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
145  maskDimSizes);
146  })
147  .Case<arith::ConstantOp>([&](auto constantOp)
148  -> std::optional<Operation *> {
149  // TODO: Support multiple dimensions.
150  if (maskShape.size() != 1)
151  return std::nullopt;
152  // Rearrange the original mask values to cover the whole potential
153  // loading region. For example, in the case of using byte-size for
154  // emulation, given the following mask:
155  //
156  // %mask = [0, 1, 0, 1, 0, 0]
157  //
158  // With a front offset of 1, the mask will be padded with 0s in the front
159  // and back so that:
160  // 1. It is aligned with the effective loading bits.
161  // 2. Its length is a multiple of `numSrcElemsPerDest` (and the total
162  // coverage size is a multiple of bytes). The new mask will look like
163  // this before compressing:
164  //
165  // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
166  auto originalMask =
167  cast<DenseIntElementsAttr>(constantOp.getValue());
168  SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
169  paddedMaskValues.append(originalMask.template value_begin<bool>(),
170  originalMask.template value_end<bool>());
171  paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
172 
173  // Compressing by combining every `numSrcElemsPerDest` elements:
174  SmallVector<bool> compressedMaskValues;
175  for (size_t i = 0; i < paddedMaskValues.size();
176  i += numSrcElemsPerDest) {
177  bool combinedValue = false;
178  for (int j = 0; j < numSrcElemsPerDest; ++j) {
179  combinedValue |= paddedMaskValues[i + j];
180  }
181  compressedMaskValues.push_back(combinedValue);
182  }
183  return arith::ConstantOp::create(
184  rewriter, loc,
185  DenseElementsAttr::get(newMaskType, compressedMaskValues));
186  });
187 
188  if (!newMask)
189  return failure();
190 
191  while (!extractOps.empty()) {
192  newMask =
193  vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
194  extractOps.back().getMixedPosition());
195  extractOps.pop_back();
196  }
197 
198  return *newMask;
199 }
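//
// Illustrative sketch (not generated verbatim by this helper; values chosen
// for illustration): with `numSrcElemsPerDest` = 4 and `numFrontPadElems` = 0,
// a `vector.create_mask` such as
//
//   %mask = vector.create_mask %c5 : vector<8xi1>
//
// is compressed to a mask over the container elements,
//
//   %new_mask = vector.create_mask %c2 : vector<2xi1>
//
// since ceil(5 / 4) = 2 container elements are (at least partially) covered.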
200 
201 /// Extracts 1-D subvector from a 1-D vector.
202 ///
203 /// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
204 /// from `src`, starting at `offset`. The result is also a rank-1 vector:
205 ///
206 /// vector<numElemsToExtract x !elType>
207 ///
208 /// (`!elType` is the element type of the source vector). As `offset` is a known
209 /// _static_ value, this helper hook emits `vector.extract_strided_slice`.
210 ///
211 /// EXAMPLE:
212 /// %res = vector.extract_strided_slice %src
213 /// { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
214 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
215  Value src, int64_t offset,
216  int64_t numElemsToExtract) {
217  auto vectorType = cast<VectorType>(src.getType());
218  assert(vectorType.getRank() == 1 && "expected source to be rank-1-D vector ");
219  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
220  "subvector out of bounds");
221 
222  // When extracting all available elements, just use the source vector as the
223  // result.
224  if (vectorType.getNumElements() == numElemsToExtract)
225  return src;
226 
227  auto offsets = rewriter.getI64ArrayAttr({offset});
228  auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
229  auto strides = rewriter.getI64ArrayAttr({1});
230 
231  auto resultVectorType =
232  VectorType::get({numElemsToExtract}, vectorType.getElementType());
233  return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
234  src, offsets, sizes, strides)
235  ->getResult(0);
236 }
237 
238 /// Inserts 1-D subvector into a 1-D vector.
239 ///
240 /// Inserts the input rank-1 source vector into the destination vector starting
241 /// at `offset`. As `offset` is a known _static_ value, this helper hook emits
242 /// `vector.insert_strided_slice`.
243 ///
244 /// EXAMPLE:
245 /// %res = vector.insert_strided_slice %src, %dest
246 /// {offsets = [%offset], strides [1]}
247 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
248  Value src, Value dest, int64_t offset) {
249  [[maybe_unused]] auto srcVecTy = cast<VectorType>(src.getType());
250  [[maybe_unused]] auto destVecTy = cast<VectorType>(dest.getType());
251  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
252  "expected source and dest to be rank-1 vector types");
253 
254  // If overwriting the destination vector, just return the source.
255  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
256  return src;
257 
258  auto offsets = rewriter.getI64ArrayAttr({offset});
259  auto strides = rewriter.getI64ArrayAttr({1});
260  return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
261  dest, offsets, strides);
262 }
263 
264 /// Extracts 1-D subvector from a 1-D vector.
265 ///
266 /// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
267 /// from `src`, starting at `offset`. The result is also a rank-1 vector:
268 ///
269 /// vector<numElemsToExtract x !elType>
270 ///
271 /// (`!elType` is the element type of the source vector). As `offset` is assumed
272 /// to be a _dynamic_ SSA value, this helper method generates a sequence of
273 /// `vector.extract` + `vector.insert` pairs.
274 ///
275 /// EXAMPLE:
276 /// %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
277 /// %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
278 /// %c1 = arith.constant 1 : index
279 /// %idx2 = arith.addi %offset, %c1 : index
280 /// %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
281 /// %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
282 /// (...)
283 static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
284  Value src, Value dest,
285  OpFoldResult offset,
286  int64_t numElemsToExtract) {
287  auto srcVecTy = cast<VectorType>(src.getType());
288  assert(srcVecTy.getRank() == 1 && "expected source to be rank-1-D vector ");
289  // NOTE: We are unable to take the offset into account in the following
290  // assert, hence it is still possible that the subvector is out-of-bounds even
291  // if the condition is true.
292  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
293  "subvector out of bounds");
294 
295  // When extracting all available elements, just use the source vector as the
296  // result.
297  if (srcVecTy.getNumElements() == numElemsToExtract)
298  return src;
299 
300  for (int i = 0; i < numElemsToExtract; ++i) {
301  Value extractLoc =
302  (i == 0) ? dyn_cast<Value>(offset)
303  : arith::AddIOp::create(
304  rewriter, loc, rewriter.getIndexType(),
305  dyn_cast<Value>(offset),
306  arith::ConstantIndexOp::create(rewriter, loc, i));
307  auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
308  dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
309  }
310  return dest;
311 }
312 
313 /// Inserts 1-D subvector into a 1-D vector.
314 ///
315 /// Inserts the input rank-1 source vector into the destination vector starting
316 /// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
317 /// uses a sequence of `vector.extract` + `vector.insert` pairs.
318 ///
319 /// EXAMPLE:
320 /// %v1 = vector.extract %src[0] : i2 from vector<8xi2>
321 /// %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
322 /// %c1 = arith.constant 1 : index
323 /// %idx2 = arith.addi %offset, %c1 : index
324 /// %v2 = vector.extract %src[1] : i2 from vector<8xi2>
325 /// %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
326 /// (...)
327 static Value dynamicallyInsertSubVector(OpBuilder &rewriter, Location loc,
328  Value src, Value dest,
329  OpFoldResult offset,
330  int64_t numElemsToInsert) {
331  auto srcVecTy = cast<VectorType>(src.getType());
332  auto destVecTy = cast<VectorType>(dest.getType());
333  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
334  "expected source and dest to be rank-1 vector types");
335  (void)srcVecTy;
336  (void)destVecTy;
337  assert(numElemsToInsert > 0 &&
338  "the number of elements to insert must be greater than 0");
339  // NOTE: We are unable to take the offset into account in the following
340  // assert, hence it is still possible that the subvector is out-of-bounds even
341  // if the condition is true.
342  assert(numElemsToInsert <= destVecTy.getNumElements() &&
343  "subvector out of bounds");
344 
345  Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
346  for (int64_t i = 0; i < numElemsToInsert; ++i) {
347  auto insertLoc =
348  i == 0 ? destOffsetVal
349  : arith::AddIOp::create(
350  rewriter, loc, rewriter.getIndexType(), destOffsetVal,
351  arith::ConstantIndexOp::create(rewriter, loc, i));
352  auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
353  dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
354  }
355  return dest;
356 }
357 
358 /// Emulate a vector load for `emulatedElemTy` using `containerElemTy`
359 ///
360 /// Specifically, use `containerElemTy` for loading a vector of
361 /// `emulatedElemTy`. The load location is given by `base` and
362 /// `linearizedIndices`, and the load size is given by
363 /// `numContainerElemsToLoad` (expressed in container elements).
364 static Value emulatedVectorLoad(OpBuilder &rewriter, Location loc,
365  Value base,
366  OpFoldResult linearizedIndices,
367  int64_t numContainerElemsToLoad,
368  Type emulatedElemTy,
369  Type containerElemTy) {
370  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
371  emulatedElemTy.getIntOrFloatBitWidth();
372  auto newLoad = vector::LoadOp::create(
373  rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
374  base, getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
375  return vector::BitCastOp::create(
376  rewriter, loc,
377  VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
378  emulatedElemTy),
379  newLoad);
380 }
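//
// Illustrative sketch (assuming i4 emulated via i8 and
// `numContainerElemsToLoad` = 2; not taken from an existing test): the helper
// emits roughly
//
//   %load = vector.load %base[%linearizedIndices] : memref<?xi8>, vector<2xi8>
//   %res  = vector.bitcast %load : vector<2xi8> to vector<4xi4>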
381 
382 /// Downcasts two values to `downcastType`, selects between them
383 /// based on `mask`, and casts the result to `upcastType`.
384 static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
385  VectorType downcastType,
386  VectorType upcastType, Value mask,
387  Value trueValue, Value falseValue) {
388  assert(
389  downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
390  upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
391  "expected input and output number of bits to match");
392  if (trueValue.getType() != downcastType) {
393  trueValue =
394  vector::BitCastOp::create(builder, loc, downcastType, trueValue);
395  }
396  if (falseValue.getType() != downcastType) {
397  falseValue =
398  vector::BitCastOp::create(builder, loc, downcastType, falseValue);
399  }
400  Value selectedType =
401  arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
402  // Upcast the selected value to the new type.
403  return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
404 }
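//
// Illustrative sketch (types chosen for illustration only): with
// `downcastType` = vector<4xi2> and `upcastType` = vector<1xi8>, this emits
// roughly
//
//   %t = vector.bitcast %trueValue : vector<1xi8> to vector<4xi2>   // if needed
//   %f = vector.bitcast %falseValue : vector<1xi8> to vector<4xi2>  // if needed
//   %s = arith.select %mask, %t, %f : vector<4xi1>, vector<4xi2>
//   %r = vector.bitcast %s : vector<4xi2> to vector<1xi8>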
405 
406 /// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
407 /// byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector of
408 /// sub-byte elements whose total size is 8 bits (one byte); the mask selects
409 /// which of those elements to store.
410 ///
411 /// Inputs:
412 /// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
413 /// storeIdx = 2
414 /// valueToStore = |3|3|3|3| : vector<4xi2>
415 /// mask = |0|0|1|1| : vector<4xi1>
416 ///
417 /// Result:
418 /// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
419 static void atomicRMW(OpBuilder &builder, Location loc,
420  MemRefValue linearizedMemref, Value storeIdx,
421  VectorValue valueToStore, Value mask) {
422  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
423 
424  // Create an atomic load-modify-write region using
425  // `memref.generic_atomic_rmw`.
426  auto atomicOp = memref::GenericAtomicRMWOp::create(
427  builder, loc, linearizedMemref, ValueRange{storeIdx});
428  Value origValue = atomicOp.getCurrentValue();
429 
430  OpBuilder::InsertionGuard guard(builder);
431  builder.setInsertionPointToStart(atomicOp.getBody());
432 
433  // Load the original value from memory, and cast it to the original element
434  // type.
435  auto oneElemVecType = VectorType::get({1}, origValue.getType());
436  Value origVecValue = vector::FromElementsOp::create(
437  builder, loc, oneElemVecType, ValueRange{origValue});
438 
439  // Construct the final masked value and yield it.
440  Value maskedValue =
441  downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
442  oneElemVecType, mask, valueToStore, origVecValue);
443  auto scalarMaskedValue =
444  vector::ExtractOp::create(builder, loc, maskedValue, 0);
445  memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
446 }
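//
// Illustrative sketch of the IR built by this helper (simplified; an i8
// container assumed, and the downcast/select/upcast step abbreviated):
//
//   %res = memref.generic_atomic_rmw %linearizedMemref[%storeIdx] : memref<?xi8> {
//   ^bb0(%curr: i8):
//     %curr_vec = vector.from_elements %curr : vector<1xi8>
//     %masked   = ... // select %valueToStore over %curr_vec under %mask, then upcast
//     %scalar   = vector.extract %masked[0] : i8 from vector<1xi8>
//     memref.atomic_yield %scalar : i8
//   }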
447 
448 /// Generate a non-atomic read-modify-write sequence for storing to the emulated
449 /// type. It has the same logic as `atomicRMW` above, but without atomicity.
450 static void nonAtomicRMW(OpBuilder &builder, Location loc,
451  MemRefValue linearizedMemref, Value linearizedIndex,
452  VectorValue valueToStore, Value mask) {
453  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
454 
455  auto oneElemVecType =
456  VectorType::get({1}, linearizedMemref.getType().getElementType());
457  Value origVecValue =
458  vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
459  ValueRange{linearizedIndex});
460  origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
461  origVecValue);
462 
463  Value maskedValue =
464  downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
465  oneElemVecType, mask, valueToStore, origVecValue);
466  vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
467  linearizedIndex);
468 }
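//
// Illustrative sketch (simplified; assuming a vector<4xi2> value and an i8
// container):
//
//   %old    = vector.load %linearizedMemref[%idx] : memref<?xi8>, vector<1xi8>
//   %old_v  = vector.bitcast %old : vector<1xi8> to vector<4xi2>
//   %masked = ... // select %valueToStore over %old_v under %mask, then upcast
//   vector.store %masked, %linearizedMemref[%idx] : memref<?xi8>, vector<1xi8>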
469 
470 /// Extract `sliceNumElements` elements from the source `vector` at `extractOffset`,
471 /// and insert it into an empty vector at `insertOffset`.
472 /// Inputs:
473 /// vec_in = |0|1|2|3| : vector<4xi2>
474 /// extractOffset = 1
475 /// sliceNumElements = 2
476 /// insertOffset = 2
477 /// Output:
478 /// vec_out = |0|0|1|2| : vector<4xi2>
479 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
480  Location loc, VectorValue vector,
481  int64_t extractOffset,
482  int64_t sliceNumElements,
483  int64_t insertOffset) {
484  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
485  auto vectorElementType = vector.getType().getElementType();
486  // TODO: update and use `alignedConversionPrecondition` in the place of
487  // these asserts.
488  assert(
489  sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
490  "sliceNumElements * vector element size must be less than or equal to 8");
491  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
492  "vector element must be a valid sub-byte type");
493  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
494  auto emptyByteVector = arith::ConstantOp::create(
495  rewriter, loc,
496  VectorType::get({emulatedPerContainerElem}, vectorElementType),
497  rewriter.getZeroAttr(
498  VectorType::get({emulatedPerContainerElem}, vectorElementType)));
499  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
500  extractOffset, sliceNumElements);
501  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
502  insertOffset);
503 }
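//
// Illustrative sketch for the example above (i2 elements, extractOffset = 1,
// sliceNumElements = 2, insertOffset = 2); the emitted IR is roughly
//
//   %zeros = arith.constant dense<0> : vector<4xi2>
//   %slice = vector.extract_strided_slice %vec_in
//              {offsets = [1], sizes = [2], strides = [1]}
//              : vector<4xi2> to vector<2xi2>
//   %out   = vector.insert_strided_slice %slice, %zeros
//              {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>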
504 
505 namespace {
506 
507 //===----------------------------------------------------------------------===//
508 // ConvertVectorStore
509 //===----------------------------------------------------------------------===//
510 
511 // Emulate `vector.store` using a multi-byte container type.
512 //
513 // The container type is obtained through the op adaptor and would normally be
514 // generated via `NarrowTypeEmulationConverter`.
515 //
516 // EXAMPLE 1
517 // (aligned store of i4, emulated using i8 as the container type)
518 //
519 // vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
520 //
521 // is rewritten as:
522 //
523 // %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
524 // vector.store %src_bitcast, %dest_bitcast[%idx]
525 // : memref<16xi8>, vector<4xi8>
526 //
527 // EXAMPLE 2
528 // (unaligned store of i2, emulated using i8 as the container type)
529 //
530 // vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
531 //
532 // The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
533 // is modelled using 3 bytes:
534 //
535 // Byte 0 Byte 1 Byte 2
536 // +----------+----------+----------+
537 // | oooooooo | ooooNNNN | NNoooooo |
538 // +----------+----------+----------+
539 //
540 // N - (N)ew entries (i.e. to be overwritten by vector.store)
541 // o - (o)ld entries (to be preserved)
542 //
543 // For the generated output in the non-atomic case, see:
544 // * `@vector_store_i2_const_index_two_partial_stores`
545 // in:
546 // * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
547 //
548 // NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
549 // `true` to generate non-atomic RMW sequences.
550 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
551  using OpConversionPattern::OpConversionPattern;
552 
553  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
554  : OpConversionPattern<vector::StoreOp>(context),
555  disableAtomicRMW(disableAtomicRMW) {}
556 
557  LogicalResult
558  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
559  ConversionPatternRewriter &rewriter) const override {
560 
561  if (op.getValueToStore().getType().getRank() != 1)
562  return rewriter.notifyMatchFailure(op,
563  "only 1-D vectors are supported ATM");
564 
565  auto loc = op.getLoc();
566 
567  auto valueToStore = cast<VectorValue>(op.getValueToStore());
568  auto containerElemTy =
569  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
570  Type emulatedElemTy = op.getValueToStore().getType().getElementType();
571  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
572  int containerBits = containerElemTy.getIntOrFloatBitWidth();
573 
574  // Check per-element alignment.
575  if (containerBits % emulatedBits != 0) {
576  return rewriter.notifyMatchFailure(
577  op, "impossible to pack emulated elements into container elements "
578  "(bit-wise misalignment)");
579  }
580  int emulatedPerContainerElem = containerBits / emulatedBits;
581 
582  // Adjust the number of elements to store when emulating narrow types.
583  // Here only the 1-D vector store is considered, and the N-D memref types
584  // should be linearized.
585  // For example, to emulate i4 to i8, the following op:
586  //
587  // vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
588  //
589  // can be replaced with
590  //
591  // %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
592  // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
593  // vector<4xi8>
594 
595  auto origElements = valueToStore.getType().getNumElements();
596  // Note, per-element-alignment was already verified above.
597  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
598  // Does the trailing dim of the source and destination match? If yes, then the
599  // corresponding index must be 0.
600  // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
601  // However, that makes some tests fail, so we need to audit first.
602  auto trailingDim = op.getBase().getType().getShape().back();
603  bool trailingDimsMatch =
604  ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
605 
606  auto stridedMetadata =
607  memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
608 
609  // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
610  // non-zero. As such, this is not needed.
611  OpFoldResult linearizedIndices;
612  memref::LinearizedMemRefInfo linearizedInfo;
613  std::tie(linearizedInfo, linearizedIndices) =
614  memref::getLinearizedMemRefOffsetAndSize(
615  rewriter, loc, emulatedBits, containerBits,
616  stridedMetadata.getConstifiedMixedOffset(),
617  stridedMetadata.getConstifiedMixedSizes(),
618  stridedMetadata.getConstifiedMixedStrides(),
619  getAsOpFoldResult(adaptor.getIndices()));
620 
621  std::optional<int64_t> foldedNumFrontPadElems =
622  (isDivisibleInSize && trailingDimsMatch)
623  ? 0
624  : getConstantIntValue(linearizedInfo.intraDataOffset);
625 
626  if (!foldedNumFrontPadElems) {
627  return rewriter.notifyMatchFailure(
628  op, "subbyte store emulation: dynamic front padding size is "
629  "not yet implemented");
630  }
631 
632  auto memrefBase = cast<MemRefValue>(adaptor.getBase());
633 
634  // RMWs are not needed when:
635  // * no _partial_ stores are required.
636  // A partial store is defined as a store in which only a part of the
637  // container element is overwritten, e.g.
638  //
639  // Dest before (8 bits)
640  // +----------+
641  // | 11000000 |
642  // +----------+
643  //
644  // Dest after storing 0xF at offset 4 (in bits)
645  // +----------+
646  // | 11001111 |
647  // +----------+
648  //
649  // At a higher level, this translates to:
650  // 1. The source vector size (in bits) is a multiple of byte size.
651  // 2. The address of the store is aligned to the container type width
652  // boundary.
653  //
654  // EXAMPLE 1:
655  // Requires partial store:
656  // vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
657  //
658  // EXAMPLE 2:
659  // Does not require a partial store:
660  // vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
661  //
662  // TODO: Take linearizedInfo.linearizedOffset into account. This is
663  // currently not needed/used/exercised as all our tests set offset to 0.
664  bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
665 
666  if (!emulationRequiresPartialStores) {
667  // Basic case: storing full bytes.
668  auto numElements = origElements / emulatedPerContainerElem;
669  auto bitCast = vector::BitCastOp::create(
670  rewriter, loc, VectorType::get(numElements, containerElemTy),
671  op.getValueToStore());
672  rewriter.replaceOpWithNewOp<vector::StoreOp>(
673  op, bitCast.getResult(), memrefBase,
674  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
675  return success();
676  }
677 
678  // Next, handle the case when sub-byte read-modify-write
679  // sequences are needed to emulate a vector store.
680  // Here is an example:
681  //
682  // Vector to store: vector<7xi2>
683  // Value to store: 11 11 11 11 11 11 11 (all ones)
684  //
685  // Destination: memref<12xi2>
686  // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
687  //
688  // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
689  //
690  // Destination memref before:
691  //
692  // Byte 0 Byte 1 Byte 2
693  // +----------+----------+----------+
694  // | 00000000 | 00000000 | 00000000 |
695  // +----------+----------+----------+
696  //
697  // Destination memref after:
698  //
699  // Byte 0 Byte 1 Byte 2
700  // +----------+----------+----------+
701  // | 00001111 | 11111111 | 11000000 |
702  // +----------+----------+----------+
703  //
704  // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
705  // need for atomicity). Stores to Bytes 0 and Byte 2 are "partial", hence
706  // requiring RMW access (atomicity is required).
707 
708  // The index into the target memref we are storing to.
709  Value currentDestIndex =
710  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
711  // The index into the source vector we are currently processing.
712  auto currentSourceIndex = 0;
713 
714  // Build a mask used for rmw.
715  auto subWidthStoreMaskType =
716  VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
717 
718  auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
719 
720  // 1. Partial width store for the leading byte.
721  // When the store address is not aligned to the emulated width boundary, deal
722  // with the unaligned part first so that the remaining elements are aligned
723  // to the width boundary.
724  auto frontSubWidthStoreElem =
725  (emulatedPerContainerElem - *foldedNumFrontPadElems) %
726  emulatedPerContainerElem;
727  if (frontSubWidthStoreElem > 0) {
728  SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
729  if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
730  std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
731  origElements, true);
732  frontSubWidthStoreElem = origElements;
733  } else {
734  std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
735  *foldedNumFrontPadElems, true);
736  }
737  auto frontMask = arith::ConstantOp::create(
738  rewriter, loc,
739  DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
740 
741  currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
742  auto value =
743  extractSliceIntoByte(rewriter, loc, valueToStore, 0,
744  frontSubWidthStoreElem, *foldedNumFrontPadElems);
745 
746  storeFunc(rewriter, loc, memrefBase, currentDestIndex,
747  cast<VectorValue>(value), frontMask.getResult());
748  }
749 
750  if (currentSourceIndex >= origElements) {
751  rewriter.eraseOp(op);
752  return success();
753  }
754 
755  // Increment the destination index by 1 to align to the emulated width
756  // boundary.
757  auto constantOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
758  currentDestIndex = arith::AddIOp::create(
759  rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne);
760 
761  // 2. Full width store for the inner output bytes.
762  // After the previous step, the store address is aligned to the emulated
763  // width boundary.
764  int64_t fullWidthStoreSize =
765  (origElements - currentSourceIndex) / emulatedPerContainerElem;
766  int64_t numNonFullWidthElements =
767  fullWidthStoreSize * emulatedPerContainerElem;
768  if (fullWidthStoreSize > 0) {
769  auto fullWidthStorePart = staticallyExtractSubvector(
770  rewriter, loc, valueToStore, currentSourceIndex,
771  numNonFullWidthElements);
772 
773  auto originType = cast<VectorType>(fullWidthStorePart.getType());
774  auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
775  auto storeType = VectorType::get(
776  {originType.getNumElements() / emulatedPerContainerElem},
777  memrefElemType);
778  auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
779  fullWidthStorePart);
780  vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,
781  currentDestIndex);
782 
783  currentSourceIndex += numNonFullWidthElements;
784  currentDestIndex = arith::AddIOp::create(
785  rewriter, loc, rewriter.getIndexType(), currentDestIndex,
786  arith::ConstantIndexOp::create(rewriter, loc, fullWidthStoreSize));
787  }
788 
789  // 3. Partial width store for the trailing output byte.
790  // It is needed when the residual length is smaller than the emulated width,
791  // which is not covered in step 2 above.
792  auto remainingElements = origElements - currentSourceIndex;
793  if (remainingElements != 0) {
794  auto subWidthStorePart =
795  extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
796  currentSourceIndex, remainingElements, 0);
797 
798  // Generate back mask.
799  auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
800  std::fill_n(maskValues.begin(), remainingElements, 1);
801  auto backMask = arith::ConstantOp::create(
802  rewriter, loc,
803  DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
804 
805  storeFunc(rewriter, loc, memrefBase, currentDestIndex,
806  cast<VectorValue>(subWidthStorePart), backMask.getResult());
807  }
808 
809  rewriter.eraseOp(op);
810  return success();
811  }
812 
813 private:
814  const bool disableAtomicRMW;
815 };
816 
817 //===----------------------------------------------------------------------===//
818 // ConvertVectorMaskedStore
819 //===----------------------------------------------------------------------===//
820 
821 /// Converts `vector.maskedstore` operations on narrow element types to work
822 /// with wider, byte-aligned container types by adjusting the mask and using
823 /// bitcasting.
824 ///
825 /// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
826 /// (each `i8` container element holds two `i4` values) and storing with an
827 /// adjusted mask.
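///
/// A rough sketch of the rewrite for that example (SSA names and the
/// linearized memref shape are illustrative only):
///
///   %old  = vector.maskedload %base[%lin], %new_mask, %zeros
///             : memref<6xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
///   %old4 = vector.bitcast %old : vector<3xi8> to vector<6xi4>
///   %sel  = arith.select %mask, %value, %old4 : vector<6xi1>, vector<6xi4>
///   %new  = vector.bitcast %sel : vector<6xi4> to vector<3xi8>
///   vector.maskedstore %base[%lin], %new_mask, %new
///             : memref<6xi8>, vector<3xi1>, vector<3xi8>
///
/// where `%new_mask` is the compressed mask produced by `getCompressedMaskOp`
/// and `%zeros` is an all-zero pass-through vector.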
828 struct ConvertVectorMaskedStore final
829  : OpConversionPattern<vector::MaskedStoreOp> {
830  using OpConversionPattern::OpConversionPattern;
831 
832  LogicalResult
833  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
834  ConversionPatternRewriter &rewriter) const override {
835 
836  // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
837  if (op.getValueToStore().getType().getRank() != 1)
838  return rewriter.notifyMatchFailure(
839  op, "Memref in vector.maskedstore op must be flattened beforehand.");
840 
841  auto loc = op.getLoc();
842  auto containerElemTy =
843  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
844  Type emulatedElemTy = op.getValueToStore().getType().getElementType();
845  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
846  int containerBits = containerElemTy.getIntOrFloatBitWidth();
847 
848  // Check per-element alignment.
849  if (containerBits % emulatedBits != 0) {
850  return rewriter.notifyMatchFailure(
851  op, "impossible to pack emulated elements into container elements "
852  "(bit-wise misalignment)");
853  }
854 
855  int emulatedPerContainerElem = containerBits / emulatedBits;
856  int origElements = op.getValueToStore().getType().getNumElements();
857  if (origElements % emulatedPerContainerElem != 0)
858  return failure();
859 
860  auto stridedMetadata =
861  memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
862  OpFoldResult linearizedIndicesOfr;
863  memref::LinearizedMemRefInfo linearizedInfo;
864  std::tie(linearizedInfo, linearizedIndicesOfr) =
865  memref::getLinearizedMemRefOffsetAndSize(
866  rewriter, loc, emulatedBits, containerBits,
867  stridedMetadata.getConstifiedMixedOffset(),
868  stridedMetadata.getConstifiedMixedSizes(),
869  stridedMetadata.getConstifiedMixedStrides(),
870  getAsOpFoldResult(adaptor.getIndices()));
871  Value linearizedIndices =
872  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
873 
874  // Load the whole data and use arith.select to handle the corner cases.
875  //
876  // As an example, for this masked store of i4 values:
877  //
878  // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
879  //
880  // and given these input values:
881  //
882  // %mask = [0, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
883  // %0[%c0, %c0] =
884  // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
885  // %val_to_store =
886  // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
887  //
888  // we'll have the following i4 output:
889  //
890  // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
891  //
892  // Emulating the above using i8 will give:
893  //
894  // %compressed_mask = [1, 1, 1, 0] (4 * i1)
895  // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
896  // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
897  // %select_using_shifted_mask =
898  // [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
899  // %packed_data = [0x1A, 0xBC, 0xD6, 0x00] (4 * i8)
900  //
901  // Using the compressed mask to store %packed_data results in expected
902  // output.
903  //
904  // FIXME: Make an example based on the comment above work (see #115460 for
905  // reproducer).
906  FailureOr<Operation *> newMask = getCompressedMaskOp(
907  rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
908  if (failed(newMask))
909  return failure();
910 
911  auto numElements = (origElements + emulatedPerContainerElem - 1) /
912  emulatedPerContainerElem;
913  auto newType = VectorType::get(numElements, containerElemTy);
914  auto passThru = arith::ConstantOp::create(rewriter, loc, newType,
915  rewriter.getZeroAttr(newType));
916 
917  auto newLoad = vector::MaskedLoadOp::create(
918  rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
919  newMask.value()->getResult(0), passThru);
920 
921  auto newBitCastType =
922  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
923  Value valueToStore =
924  vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
925  valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
926  op.getValueToStore(), valueToStore);
927  valueToStore =
928  vector::BitCastOp::create(rewriter, loc, newType, valueToStore);
929 
930  rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
931  op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
932  valueToStore);
933  return success();
934  }
935 };
936 
937 //===----------------------------------------------------------------------===//
938 // ConvertVectorLoad
939 //===----------------------------------------------------------------------===//
940 
941 /// Converts `vector.load` on narrow element types to work with
942 /// wider, byte-aligned container types by adjusting load sizes and using
943 /// bitcasting.
944 ///
945 /// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
946 /// by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
947 /// container holds two `i4` values) and bitcasting back.
948 ///
949 /// There are cases where the number of elements to load is not byte-aligned. In
950 /// those cases, loads are converted to byte-aligned, byte-sized loads and the
951 /// target vector is extracted from the loaded vector.
952 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
953  using OpConversionPattern::OpConversionPattern;
954 
955  LogicalResult
956  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
957  ConversionPatternRewriter &rewriter) const override {
958  // Prerequisite: memref in the vector.load op is flattened into 1-D.
959  if (op.getVectorType().getRank() != 1)
960  return rewriter.notifyMatchFailure(
961  op, "Memref in emulated vector ops must be flattened beforehand.");
962 
963  auto loc = op.getLoc();
964  auto containerElemTy =
965  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
966  Type emulatedElemTy = op.getType().getElementType();
967  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
968  int containerBits = containerElemTy.getIntOrFloatBitWidth();
969 
970  // Check per-element alignment.
971  if (containerBits % emulatedBits != 0) {
972  return rewriter.notifyMatchFailure(
973  op, "impossible to pack emulated elements into container elements "
974  "(bit-wise misalignment)");
975  }
976  int emulatedPerContainerElem = containerBits / emulatedBits;
977 
978  // Adjust the number of elements to load when emulating narrow types,
979  // and then cast back to the original type with vector.bitcast op.
980  // For example, to emulate i4 to i8, the following op:
981  //
982  // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
983  //
984  // can be replaced with
985  //
986  // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
987  // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
988  //
989  // There are cases where the number of elements to load is not byte-aligned,
990  // for example:
991  //
992  // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
993  //
994  // we will have to load extra bytes and extract the exact slice in between.
995  //
996  // %1 = vector.load %0[%c0] : memref<3xi8>, vector<2xi8>
997  // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
998  // %3 = vector.extract_strided_slice %2 {offsets = [3], sizes = [3], strides
999  // = [1]}
1000  // : vector<8xi2> to vector<3xi2>
1001  //
1002  // TODO: Currently the extract_strided_slice's attributes must be known at
1003  // compile time as they must be constants.
1004 
1005  auto origElements = op.getVectorType().getNumElements();
1006  // Note, per-element-alignment was already verified above.
1007  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1008 
1009  auto stridedMetadata =
1010  memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1011 
1012  OpFoldResult linearizedIndices;
1013  memref::LinearizedMemRefInfo linearizedInfo;
1014  std::tie(linearizedInfo, linearizedIndices) =
1015  memref::getLinearizedMemRefOffsetAndSize(
1016  rewriter, loc, emulatedBits, containerBits,
1017  stridedMetadata.getConstifiedMixedOffset(),
1018  stridedMetadata.getConstifiedMixedSizes(),
1019  stridedMetadata.getConstifiedMixedStrides(),
1020  getAsOpFoldResult(adaptor.getIndices()));
1021 
1022  std::optional<int64_t> foldedIntraVectorOffset =
1023  isDivisibleInSize ? 0
1024  : getConstantIntValue(linearizedInfo.intraDataOffset);
1025 
1026  // Always load enough elements to cover the original elements.
1027  int64_t maxintraDataOffset =
1028  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1029  auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
1030  emulatedPerContainerElem);
1031  Value result =
1032  emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
1033  numElements, emulatedElemTy, containerElemTy);
1034 
1035  if (!foldedIntraVectorOffset) {
1036  auto resultVector = arith::ConstantOp::create(
1037  rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1038  result = dynamicallyExtractSubVector(
1039  rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
1040  linearizedInfo.intraDataOffset, origElements);
1041  } else if (!isDivisibleInSize) {
1042  result = staticallyExtractSubvector(
1043  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1044  }
1045  rewriter.replaceOp(op, result);
1046  return success();
1047  }
1048 };
1049 
1050 //===----------------------------------------------------------------------===//
1051 // ConvertVectorMaskedLoad
1052 //===----------------------------------------------------------------------===//
1053 
1054 /// Converts `vector.maskedload` operations on narrow element types to work with
1055 /// wider, byte-aligned container types by adjusting the mask and using
1056 /// bitcasting.
1057 ///
1058 /// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
1059 /// bitcasting, since each `i8` container element holds two `i4` values.
1060 struct ConvertVectorMaskedLoad final
1061  : OpConversionPattern<vector::MaskedLoadOp> {
1062  using OpConversionPattern::OpConversionPattern;
1063 
1064  LogicalResult
1065  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
1066  ConversionPatternRewriter &rewriter) const override {
1067  if (op.getVectorType().getRank() != 1)
1068  return rewriter.notifyMatchFailure(
1069  op, "Memref in emulated vector ops must be flattened beforehand.");
1070 
1071  auto loc = op.getLoc();
1072 
1073  auto containerElemTy =
1074  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1075  Type emulatedElemTy = op.getType().getElementType();
1076  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1077  int containerBits = containerElemTy.getIntOrFloatBitWidth();
1078 
1079  // Check per-element alignment.
1080  if (containerBits % emulatedBits != 0) {
1081  return rewriter.notifyMatchFailure(
1082  op, "impossible to pack emulated elements into container elements "
1083  "(bit-wise misalignment)");
1084  }
1085  int emulatedPerContainerElem = containerBits / emulatedBits;
1086 
1087  // Adjust the number of elements to load when emulating narrow types,
1088  // and then cast back to the original type with vector.bitcast op.
1089  // For example, to emulate i4 to i8, the following op:
1090  //
1091  // %mask = vector.constant_mask [3] : vector<6xi1>
1092  // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
1093  // memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
1094  //
1095  // can be replaced with
1096  //
1097  // %new_mask = vector.constant_mask [2] : vector<3xi1>
1098  // %new_pass_thru = vector.bitcast %pass_thru :
1099  // vector<6xi4> to vector<3xi8>
1100  // %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
1101  // memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
1102  // %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
1103  //
1104  // Since we are effectively loading 16 bits (2xi8) from the memref with the
1105  // new mask, while originally we only wanted to effectively load 12 bits
1106  // (3xi4) from the memref, we need to set the second half of the last i8
1107  // that was effectively loaded (i.e. the second i8) to %pass_thru.
1108  //
1109  // %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
1110  //
1111  // Given these input values:
1112  // %mask = [1, 1, 1, 0, 0, 0]
1113  // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
1114  // %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
1115  //
1116  // we'll have:
1117  //
1118  // expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1119  //
1120  // %new_mask = [1, 1, 0]
1121  // %new_pass_thru = [0x78, 0x9A, 0xBC]
1122  // %1 = [0x12, 0x34, 0xBC]
1123  // %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
1124  // %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
1125  //
1126  // TODO: Currently, only loading an even number of elements is supported.
1127  // To deal with the odd number of elements, one has to extract the
1128  // subvector at the proper offset after bit-casting.
1129  auto origType = op.getVectorType();
1130  auto origElements = origType.getNumElements();
1131  // Note, per-element-alignment was already verified above.
1132  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
1133 
1134  auto stridedMetadata =
1135  memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1136  OpFoldResult linearizedIndices;
1137  memref::LinearizedMemRefInfo linearizedInfo;
1138  std::tie(linearizedInfo, linearizedIndices) =
1139  memref::getLinearizedMemRefOffsetAndSize(
1140  rewriter, loc, emulatedBits, containerBits,
1141  stridedMetadata.getConstifiedMixedOffset(),
1142  stridedMetadata.getConstifiedMixedSizes(),
1143  stridedMetadata.getConstifiedMixedStrides(),
1144  getAsOpFoldResult(adaptor.getIndices()));
1145 
1146  std::optional<int64_t> foldedIntraVectorOffset =
1147  isDivisibleInSize ? 0
1148  : getConstantIntValue(linearizedInfo.intraDataOffset);
1149 
1150  int64_t maxIntraDataOffset =
1151  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1152  FailureOr<Operation *> newMask =
1153  getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
1154  emulatedPerContainerElem, maxIntraDataOffset);
1155  if (failed(newMask))
1156  return failure();
1157 
1158  Value passthru = op.getPassThru();
1159 
1160  auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1161  emulatedPerContainerElem);
1162  auto loadType = VectorType::get(numElements, containerElemTy);
1163  auto newBitcastType =
1164  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
1165 
1166  auto emptyVector = arith::ConstantOp::create(
1167  rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
1168  if (!foldedIntraVectorOffset) {
1169  passthru = dynamicallyInsertSubVector(
1170  rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
1171  origElements);
1172  } else if (!isDivisibleInSize) {
1173  passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
1174  *foldedIntraVectorOffset);
1175  }
1176  auto newPassThru =
1177  vector::BitCastOp::create(rewriter, loc, loadType, passthru);
1178 
1179  // Generating the new masked load.
1180  auto newLoad = vector::MaskedLoadOp::create(
1181  rewriter, loc, loadType, adaptor.getBase(),
1182  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1183  newMask.value()->getResult(0), newPassThru);
1184 
1185  // Setting the part that originally was not effectively loaded from memory
1186  // to pass through.
1187  auto bitCast =
1188  vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);
1189 
1190  Value mask = op.getMask();
1191  auto newSelectMaskType = VectorType::get(
1192  numElements * emulatedPerContainerElem, rewriter.getI1Type());
1193  // TODO: try to fold if op's mask is constant
1194  auto emptyMask =
1195  arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
1196  rewriter.getZeroAttr(newSelectMaskType));
1197  if (!foldedIntraVectorOffset) {
1198  mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
1199  linearizedInfo.intraDataOffset,
1200  origElements);
1201  } else if (!isDivisibleInSize) {
1202  mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
1203  *foldedIntraVectorOffset);
1204  }
1205 
1206  Value result =
1207  arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
1208  if (!foldedIntraVectorOffset) {
1209  result = dynamicallyExtractSubVector(
1210  rewriter, loc, result, op.getPassThru(),
1211  linearizedInfo.intraDataOffset, origElements);
1212  } else if (!isDivisibleInSize) {
1213  result = staticallyExtractSubvector(
1214  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1215  }
1216  rewriter.replaceOp(op, result);
1217 
1218  return success();
1219  }
1220 };
1221 
1222 /// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
1223 ///
1224 /// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1225 /// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1226 /// (a multi-byte scalar, e.g. i16), where N is some integer.
1227 ///
1228 /// Put differently, this method checks whether this would be valid:
1229 ///
1230 /// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1231 ///
1232 /// EXAMPLES:
1233 /// * vector<4xi4> -> i16 - yes (N = 1)
1234 /// * vector<4xi4> -> i8 - yes (N = 2)
1235 /// * vector<3xi4> -> i8 - no (N would have to be 1.5)
1236 /// * vector<3xi2> -> i16 - no (N would have to be 0.375)
1237 static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1238  Type multiByteScalarTy) {
1239  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1240 
1241  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1242  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
1243 
1244  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1245  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1246  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
1247 
1248  int elemsPerMultiByte = multiByteBits / subByteBits;
1249 
1250  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // ConvertVectorTransferRead
1255 //===----------------------------------------------------------------------===//
1256 
1257 // TODO: Document-me
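// A short summary, inferred from the pattern below: a narrow-element
// `vector.transfer_read` is emulated by (1) zero-extending the padding value
// to the container element type, (2) reading enough container elements from
// the linearized memref with a wider `vector.transfer_read`, (3) bitcasting
// the result back to the emulated element type, and (4) extracting the
// original elements when the access is not aligned to the container width.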
1258 struct ConvertVectorTransferRead final
1259  : OpConversionPattern<vector::TransferReadOp> {
1260  using OpConversionPattern::OpConversionPattern;
1261 
1262  LogicalResult
1263  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
1264  ConversionPatternRewriter &rewriter) const override {
1265 
1266  // Prerequisites: memref in the vector.transfer_read op is flattened into
1267  // 1-D.
1268  if (op.getVectorType().getRank() != 1)
1269  return rewriter.notifyMatchFailure(
1270  op, "Memref in emulated vector ops must be flattened beforehand.");
1271 
1272  auto loc = op.getLoc();
1273  auto containerElemTy =
1274  cast<MemRefType>(adaptor.getBase().getType()).getElementType();
1275  Type emulatedElemTy = op.getType().getElementType();
1276  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
1277  int containerBits = containerElemTy.getIntOrFloatBitWidth();
1278 
1279  // Check per-element alignment.
1280  if (containerBits % emulatedBits != 0) {
1281  return rewriter.notifyMatchFailure(
1282  op, "impossible to pack emulated elements into container elements "
1283  "(bit-wise misalignment)");
1284  }
1285  int emulatedPerContainerElem = containerBits / emulatedBits;
1286 
1287  auto origElements = op.getVectorType().getNumElements();
1288 
1289  // Note, per-element-alignment was already verified above.
1290  bool isDivisibleInSize =
1291  fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
1292 
1293  // Pad the padding value with 0s on the left. These bits are discarded and
1294  // thus their values don't matter.
1295  Value padding = adaptor.getPadding();
1296  if (!padding.getType().isInteger()) {
1297  padding = arith::BitcastOp::create(
1298  rewriter, loc,
1299  IntegerType::get(rewriter.getContext(),
1300  padding.getType().getIntOrFloatBitWidth()),
1301  padding);
1302  }
1303  auto newPadding =
1304  arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);
1305 
1306  auto stridedMetadata =
1307  memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
1308 
1309  OpFoldResult linearizedIndices;
1310  memref::LinearizedMemRefInfo linearizedInfo;
1311  std::tie(linearizedInfo, linearizedIndices) =
1312  memref::getLinearizedMemRefOffsetAndSize(
1313  rewriter, loc, emulatedBits, containerBits,
1314  stridedMetadata.getConstifiedMixedOffset(),
1315  stridedMetadata.getConstifiedMixedSizes(),
1316  stridedMetadata.getConstifiedMixedStrides(),
1317  getAsOpFoldResult(adaptor.getIndices()));
1318 
1319  std::optional<int64_t> foldedIntraVectorOffset =
1320  isDivisibleInSize ? 0
1321  : getConstantIntValue(linearizedInfo.intraDataOffset);
1322 
1323  int64_t maxIntraDataOffset =
1324  foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
1325  auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
1326  emulatedPerContainerElem);
1327 
1328  auto newRead = vector::TransferReadOp::create(
1329  rewriter, loc, VectorType::get(numElements, containerElemTy),
1330  adaptor.getBase(),
1331  getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
1332  newPadding);
1333 
1334  auto bitCast = vector::BitCastOp::create(
1335  rewriter, loc,
1336  VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
1337  newRead);
1338 
1339  Value result = bitCast->getResult(0);
1340  if (!foldedIntraVectorOffset) {
1341  auto zeros = arith::ConstantOp::create(
1342  rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
1343  result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
1344  linearizedInfo.intraDataOffset,
1345  origElements);
1346  } else if (!isDivisibleInSize) {
1347  result = staticallyExtractSubvector(
1348  rewriter, loc, result, *foldedIntraVectorOffset, origElements);
1349  }
1350  rewriter.replaceOp(op, result);
1351 
1352  return success();
1353  }
1354 };
1355 } // end anonymous namespace
1356 
1357 //===----------------------------------------------------------------------===//
1358 // RewriteBitCastOfTruncI
1359 //===----------------------------------------------------------------------===//
1360 
1361 namespace {
1362 
1363 /// Helper struct to keep track of the provenance of a contiguous set of bits
1364 /// in a source vector.
1365 struct SourceElementRange {
1366  /// The index of the source vector element that contributes bits to *this.
1367  int64_t sourceElementIdx;
1368  /// The range of bits in the source vector element that contribute to *this.
1369  int64_t sourceBitBegin;
1370  int64_t sourceBitEnd;
1371 };
1372 
1373 struct SourceElementRangeList : public SmallVector<SourceElementRange> {
1374  /// Given the index of a SourceElementRange in the SourceElementRangeList,
1375  /// compute the amount of bits that need to be shifted to the left to get the
1376  /// bits in their final location. This shift amount is simply the sum of the
1377  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
1378 /// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
1379  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
1380  int64_t res = 0;
1381  for (int64_t i = 0; i < shuffleIdx; ++i)
1382  res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
1383  return res;
1384  }
1385 };
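// Illustrative example (not from the upstream file): for a
// SourceElementRangeList containing { {0, [0, 5)}, {1, [0, 3)} },
// computeLeftShiftAmount(0) == 0 (the first range lands in the LSBs) and
// computeLeftShiftAmount(1) == 5 (the 5 bits before shuffleIdx 1 are summed).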
1386 
1387 /// Helper struct to enumerate the source elements and bit ranges that are
1388 /// involved in a bitcast operation.
1389 /// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
1390 /// any 1-D vector shape and any source/target bitwidths.
1391 /// This creates and holds a mapping of the form:
1392 /// [dstVectorElementJ] ==
1393 /// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
1394 /// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
1395 /// [0] = {0, [0-8)}
1396 /// [1] = {0, [8-16)}
1397 /// [2] = {0, [16-24)}
1398 /// and `vector.bitcast ... : vector<3xi10> to vector<2xi15>` is decomposed as:
1399 /// [0] = {0, [0, 10)}, {1, [0, 5)}
1400 /// [1] = {1, [5, 10)}, {2, [0, 10)}
1401 struct BitCastBitsEnumerator {
1402  BitCastBitsEnumerator(VectorType sourceVectorType,
1403  VectorType targetVectorType);
1404 
1405  int64_t getMaxNumberOfEntries() {
1406  int64_t numVectors = 0;
1407  for (const auto &l : sourceElementRanges)
1408  numVectors = std::max(numVectors, (int64_t)l.size());
1409  return numVectors;
1410  }
1411 
1412  VectorType sourceVectorType;
1413  VectorType targetVectorType;
1414  SmallVector<SourceElementRangeList> sourceElementRanges;
1415 };
1416 
1417 /// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
1418 /// advantage of high-level information to avoid leaving LLVM to scramble with
1419 /// peephole optimizations.
1420 /// BitCastBitsEnumerator encodes for each element of the target vector the
1421 /// provenance of the bits in the source vector. We can "transpose" this
1422 /// information to build a sequence of shuffles and bitwise ops that will
1423 /// produce the desired result.
1424 //
1425 /// Consider the following motivating example:
1426 /// ```
1427 /// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
1428 /// ```
1429 //
1430 /// BitCastBitsEnumerator contains the following information:
1431 /// ```
1432 /// { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
1433 /// { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
1434 /// { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
1435 /// { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
1436 /// { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
1437 /// { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
1438 /// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
1439 /// {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
1440 /// {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
1441 /// {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
1442 /// {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
1443 /// {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
1444 /// {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
1445 /// {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
1446 /// {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
1447 /// {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
1448 /// {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
1449 /// {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
1450 /// {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
1451 /// {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
1452 /// ```
1453 ///
1454 /// In the above, each row represents one target vector element and each
1455 /// column represents one bit contribution from a source vector element.
1456 /// The algorithm creates vector.shuffle operations (in this case there are 3
1457 /// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
1458 /// algorithm populates the bits as follows:
1459 /// ```
1460 /// src bits 0 ...
1461 /// 1st shuffle |xxxxx |xx |...
1462 /// 2nd shuffle | xxx| xxxxx |...
1463 /// 3rd shuffle | | x|...
1464 /// ```
1465 //
1466 /// The algorithm proceeds as follows:
1467 /// 1. for each vector.shuffle, collect the source vectors that participate in
1468 /// this shuffle. One source vector per target element of the resulting
1469 /// vector.shuffle. If there is no source element contributing bits for the
1470 /// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
1471 /// 2 columns).
1472 /// 2. represent the bitrange in the source vector as a mask. If there is no
1473 /// source element contributing bits for the current vector.shuffle, take 0.
1474 /// 3. shift right by the proper amount to align the source bitrange at
1475 /// position 0. This is exactly the low end of the bitrange. For instance,
1476 /// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
1477 /// shift right by 3 to get the bits contributed by the source element #1
1478 /// into position 0.
1479 /// 4. shift left by the proper amount to align to the desired position in
1480 /// the result element vector. For instance, the contribution of the second
1481 /// source element for the first row needs to be shifted by `5` to form the
1482 /// first i8 result element.
1483 ///
1484 /// Eventually, we end up building the sequence
1485 /// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
1486 /// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
1487 /// bits extracted from the source vector (i.e. the `shuffle -> and` part).
1488 struct BitCastRewriter {
1489  /// Helper metadata struct to hold the static quantities for the rewrite.
1490  struct Metadata {
1491  SmallVector<int64_t> shuffles;
1492  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1493  };
1494 
1495  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
1496 
1497  /// Verify that general preconditions for the rewrite are met.
1498  LogicalResult commonPrecondition(PatternRewriter &rewriter,
1499  VectorType preconditionType, Operation *op);
1500 
1501  /// Precompute the metadata for the rewrite.
1502  SmallVector<BitCastRewriter::Metadata>
1503  precomputeMetadata(IntegerType shuffledElementType);
1504 
1505  /// Rewrite one step of the sequence:
1506  /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
1507  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
1508  Value initialValue, Value runningResult,
1509  const BitCastRewriter::Metadata &metadata);
1510 
1511 private:
1512  /// Underlying enumerator that encodes the provenance of the bits in the each
1513  /// element of the result vector.
1514  BitCastBitsEnumerator enumerator;
1515 };
1516 
1517 } // namespace
1518 
1519 [[maybe_unused]] static raw_ostream &
1520 operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
1521  for (const auto &l : vec) {
1522  for (auto it : llvm::enumerate(l)) {
1523  os << "{ " << it.value().sourceElementIdx << ": b@["
1524  << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
1525  << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
1526  }
1527  os << "\n";
1528  }
1529  return os;
1530 }
1531 
1532 BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
1533  VectorType targetVectorType)
1534  : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
1535 
1536  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
1537  "requires 1-D non-scalable vector type");
1538  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
1539  "requires 1-D non-scalable vector type");
1540  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
1541  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
1542  LDBG() << "sourceVectorType: " << sourceVectorType;
1543 
1544  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
1545  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
1546  LDBG() << "targetVectorType: " << targetVectorType;
1547 
1548  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
1549  (void)mostMinorSourceDim;
1550  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
1551  "source and target bitwidths must match");
1552 
1553  // Prepopulate one source element range per target element.
1554  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
1555  for (int64_t resultBit = 0; resultBit < bitwidth;) {
1556  int64_t resultElement = resultBit / targetBitWidth;
1557  int64_t resultBitInElement = resultBit % targetBitWidth;
1558  int64_t sourceElementIdx = resultBit / sourceBitWidth;
1559  int64_t sourceBitInElement = resultBit % sourceBitWidth;
1560  int64_t step = std::min(sourceBitWidth - sourceBitInElement,
1561  targetBitWidth - resultBitInElement);
1562  sourceElementRanges[resultElement].push_back(
1563  {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
1564  resultBit += step;
1565  }
1566 }
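// Illustrative trace (not from the upstream file) of the loop above for the
// `vector<3xi10> to vector<2xi15>` example from the BitCastBitsEnumerator
// documentation (bitwidth = 30):
//   resultBit =  0: step = min(10, 15) = 10 -> [0] += {0, [0, 10)}
//   resultBit = 10: step = min(10, 5)  = 5  -> [0] += {1, [0, 5)}
//   resultBit = 15: step = min(5, 15)  = 5  -> [1] += {1, [5, 10)}
//   resultBit = 20: step = min(10, 10) = 10 -> [1] += {2, [0, 10)}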
1567 
1568 BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
1569  VectorType targetVectorType)
1570  : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
1571  LDBG() << "\n" << enumerator.sourceElementRanges;
1572 }
1573 
1574 /// Verify that the precondition type meets the common preconditions for any
1575 /// conversion.
1576 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
1577  VectorType preconditionType,
1578  Operation *op) {
1579  if (!preconditionType || preconditionType.isScalable())
1580  return rewriter.notifyMatchFailure(op, "scalable vector");
1581 
1582  // TODO: consider relaxing this restriction in the future if we find ways
1583  // to really work with subbyte elements across the MLIR/LLVM boundary.
1584  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
1585  if (bitwidth % 8 != 0)
1586  return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
1587 
1588  return success();
1589 }
1590 
1591 LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
1592  VectorType preconditionType,
1593  Operation *op) {
1594  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
1595  return rewriter.notifyMatchFailure(op, "types are not vector");
1596 
1597  if (!preconditionType || preconditionType.getRank() != 1)
1598  return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
1599 
1600  return commonConversionPrecondition(rewriter, preconditionType, op);
1601 }
1602 
1603 /// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1604 ///
1605 /// Alignment means that `subByteVecTy` can be packed into a vector of
1606 /// `containerTy` elements. More specifically:
1607 /// 1. The bit-width of `containerTy` is a multiple of the
1608 /// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1609 /// this multiple is 4.
1610 /// 2. The multiple from 1. above divides evenly the number of the (trailing)
1611 /// elements in `subByteVecTy`.
1612 ///
1613 /// EXAMPLE 1:
1614 /// `subByteVecTy = vector<2xi4>`, and
1615 /// `containerTy = i16`
1616 ///
1617 /// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
1618 ///
1619 /// EXAMPLE 2:
1620 /// `subByteVecTy = vector<3xi4>`, and
1621 /// `containerTy = i16`
1622 ///
1623 /// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
1624 ///
1625 /// EXAMPLE 3:
1626 /// `subByteVecTy = vector<3xi3>`, and
1627 /// `containerTy = i16`
1628 ///
1629 /// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
1630 ///
1631 /// NOTE: This method assumes that common conversion preconditions are met. In
1632 /// particular, `containerTy` is assumed to be a
1633 /// multi-byte scalar type (e.g., i8, i16, i32).
1634 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1635  VectorType subByteVecTy,
1636  Type containerTy,
1637  Operation *op) {
1638  assert(containerTy.isIntOrFloat() &&
1639  "container element type is not a scalar");
1640 
1641  // TODO: This is validating the inputs rather than checking the conditions
1642  // documented above. Replace with an assert.
1643  if (!subByteVecTy)
1644  return rewriter.notifyMatchFailure(op, "not a vector!");
1645 
1646  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1647  unsigned containerBits = containerTy.getIntOrFloatBitWidth();
1648 
1649  // Enforced by the common pre-conditions.
1650  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
1651 
1652  // TODO: Add support for other widths (when/if needed)
1653  if (subByteBits != 2 && subByteBits != 4)
1654  return rewriter.notifyMatchFailure(
1655  op, "only 2-bit and 4-bit sub-byte types are supported at the moment");
1656 
1657  // Condition 1 ("per-element" alignment)
1658  if (containerBits % subByteBits != 0)
1659  return rewriter.notifyMatchFailure(op, "unaligned element types");
1660 
1661  // Condition 2 ("full" alignment)
1662  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
1663  return rewriter.notifyMatchFailure(
1664  op, "not possible to fit this sub-byte vector type into a vector of "
1665  "the given multi-byte type");
1666 
1667  return success();
1668 }
1669 
1670 SmallVector<BitCastRewriter::Metadata>
1671 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
1672  SmallVector<BitCastRewriter::Metadata> result;
1673  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
1674  shuffleIdx < e; ++shuffleIdx) {
1675  SmallVector<int64_t> shuffles;
1676  SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
1677 
1678  // Create the attribute quantities for the shuffle / mask / shift ops.
1679  for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
1680  int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
1681  ? srcEltRangeList[shuffleIdx].sourceElementIdx
1682  : 0;
1683  shuffles.push_back(sourceElement);
1684 
1685  int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
1686  ? srcEltRangeList[shuffleIdx].sourceBitBegin
1687  : 0;
1688  int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
1689  ? srcEltRangeList[shuffleIdx].sourceBitEnd
1690  : 0;
1691  IntegerAttr mask = IntegerAttr::get(
1692  shuffledElementType,
1693  llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
1694  bitLo, bitHi));
1695  masks.push_back(mask);
1696 
1697  int64_t shiftRight = bitLo;
1698  shiftRightAmounts.push_back(
1699  IntegerAttr::get(shuffledElementType, shiftRight));
1700 
1701  int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
1702  shiftLeftAmounts.push_back(
1703  IntegerAttr::get(shuffledElementType, shiftLeft));
1704  }
1705 
1706  result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
1707  }
1708  return result;
1709 }
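// Illustrative example (not from the upstream file): for a narrowing bitcast
// `vector<2xi16> to vector<4xi8>`, every target element draws its bits from a
// single source element, so getMaxNumberOfEntries() == 1 and a single
// Metadata entry is produced. Assuming `shuffledElementType` is i16 (as in the
// ext-of-bitcast path), that entry would be:
//   shuffles          = [0, 0, 1, 1]
//   masks             = [0x00FF, 0xFF00, 0x00FF, 0xFF00]
//   shiftRightAmounts = [0, 8, 0, 8]
//   shiftLeftAmounts  = [0, 0, 0, 0]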
1710 
1711 Value BitCastRewriter::genericRewriteStep(
1712  PatternRewriter &rewriter, Location loc, Value initialValue,
1713  Value runningResult, const BitCastRewriter::Metadata &metadata) {
1714  // Create vector.shuffle from the metadata.
1715  auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
1716  initialValue, metadata.shuffles);
1717 
1718  // Intersect with the mask.
1719  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
1720  auto constOp = arith::ConstantOp::create(
1721  rewriter, loc,
1722  DenseElementsAttr::get(shuffledVectorType, metadata.masks));
1723  Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);
1724 
1725  // Align right on 0.
1726  auto shiftRightConstantOp = arith::ConstantOp::create(
1727  rewriter, loc,
1728  DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
1729  Value shiftedRight =
1730  arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);
1731 
1732  // Shift bits left into their final position.
1733  auto shiftLeftConstantOp = arith::ConstantOp::create(
1734  rewriter, loc,
1735  DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
1736  Value shiftedLeft =
1737  arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);
1738 
1739  runningResult =
1740  runningResult
1741  ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)
1742  : shiftedLeft;
1743 
1744  return runningResult;
1745 }
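// Illustrative sketch (not from the upstream file) of the ops emitted by one
// rewrite step, using the vector<2xi16> -> vector<4xi8> metadata sketched
// after precomputeMetadata above (%src is a hypothetical vector<2xi16>):
//
//   %sh   = vector.shuffle %src, %src [0, 0, 1, 1]
//             : vector<2xi16>, vector<2xi16>
//   %mask = arith.constant dense<[255, 65280, 255, 65280]> : vector<4xi16>
//   %and  = arith.andi %sh, %mask : vector<4xi16>
//   %shrc = arith.constant dense<[0, 8, 0, 8]> : vector<4xi16>
//   %shr  = arith.shrui %and, %shrc : vector<4xi16>
//   %shlc = arith.constant dense<0> : vector<4xi16>
//   %shl  = arith.shli %shr, %shlc : vector<4xi16>
//
// On the first step runningResult is null, so %shl is returned directly; on
// later steps it is OR-ed into the running result.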
1746 
1747 /// Bitcasts the aligned `subByteVec` vector to a vector of i8.
1748 /// where aligned means it satisfies alignedConversionPrecondition.
1749 ///
1750 /// Example:
1751 /// vector<16x16xi2> -> vector<16x4xi8>
1752 /// vector<16x16xi4> -> vector<16x8xi8>
1753 static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
1754  Value subByteVec) {
1755  auto srcVecType = cast<VectorType>(subByteVec.getType());
1756  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1757  assert(8 % srcBitwidth == 0 &&
1758  "Unsupported sub-byte type (not a divisor of i8)");
1759  int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1760  SmallVector<int64_t> vecShape(srcVecType.getShape());
1761  // Adjust last dimension of the vector, so the total size remains the same.
1762  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1763  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
1764  return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
1765 }
1766 
1767 /// Extracts a signed N-bit sequence from each element of a vector of bytes,
1768 /// starting at the specified bit index.
1769 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
1770 ///
1771 /// Example for a single element:
1772 /// Extract numBits=2 starting at bitIdx=2
1773 /// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1774 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1775 /// target = [. . . . ^ ^ . .]
1776 ///
1777 /// The target sequence is [11](decimal=-1) as signed 2-bit integer.
1778 /// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
1779 ///
1780 /// src = [01 01 11 10]
1781 /// shl = arith.shl(src, 4) -> [11 10 00 00]
1782 /// result = arith.shrsi(shl, 6) -> [11 11 11 11]
1783 static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
1784  Location loc, Value src,
1785  int bitIdx, int numBits) {
1786  auto srcType = cast<VectorType>(src.getType());
1787  Value shl = src;
1788  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1789  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1790  "Invalid bitIdx range");
1791  if (bitsToShiftLeft != 0) {
1792  Value shiftLeftValues = arith::ConstantOp::create(
1793  rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
1794  shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);
1795  }
1796 
1797  int8_t bitsToShiftRight = 8 - numBits;
1798  Value shiftRightValues = arith::ConstantOp::create(
1799  rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1800  Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
1801  return shr;
1802 }
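// Illustrative sketch (not from the upstream file) of the ops emitted for
// bitIdx = 2 and numBits = 2 on a hypothetical vector<4xi8> %src:
//   %c4  = arith.constant dense<4> : vector<4xi8>   // 8 - numBits - bitIdx
//   %shl = arith.shli %src, %c4 : vector<4xi8>
//   %c6  = arith.constant dense<6> : vector<4xi8>   // 8 - numBits
//   %res = arith.shrsi %shl, %c6 : vector<4xi8>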
1803 
1804 /// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1805 /// starting at the specified bit index.
1806 /// The `bitIdx` starts at 0 from the LSB and moves to the left.
1807 ///
1808 /// Example for a single element:
1809 /// Extract numBits=2 starting at bitIdx=2
1810 /// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1811 /// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1812 /// target = [. . . . ^ ^ . .]
1813 ///
1814 /// The target sequence is [10](decimal=2) as unsigned 2-bit integer.
1815 /// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
1816 ///
1817 /// src = [01 01 10 10]
1818 /// mask = [00 00 00 11]
1819 /// shr = arith.shrui(src, 2) = [00 01 01 10]
1820 /// result = arith.andi(shr, mask) = [00 00 00 10]
1821 /// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
1822 /// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1823 /// However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift
1824 /// left when the index is 0.
1825 static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
1826  Location loc, Value src,
1827  int bitIdx, int numBits) {
1828  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1829  "Invalid bitIdx range");
1830  auto srcType = cast<VectorType>(src.getType());
1831  int8_t bitsToShiftRight = bitIdx;
1832  Value shr = src;
1833  if (bitsToShiftRight != 0) {
1834  Value shiftRightValues = arith::ConstantOp::create(
1835  rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1836  shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);
1837  }
1838  if (bitIdx + numBits == 8) {
1839  return shr;
1840  }
1841  uint8_t lowBitsMask = (1 << numBits) - 1;
1842  Value lowBitsMaskValues = arith::ConstantOp::create(
1843  rewriter, loc, DenseElementsAttr::get(srcType, lowBitsMask));
1844  return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
1845 }
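// Illustrative sketch (not from the upstream file) of the ops emitted for
// bitIdx = 2 and numBits = 2 on a hypothetical vector<4xi8> %src. Since
// bitIdx + numBits != 8, the low bits are masked after the shift:
//   %c2   = arith.constant dense<2> : vector<4xi8>
//   %shr  = arith.shrui %src, %c2 : vector<4xi8>
//   %mask = arith.constant dense<3> : vector<4xi8>  // (1 << numBits) - 1
//   %res  = arith.andi %shr, %mask : vector<4xi8>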
1846 
1847 using ExtractNBitsFn =
1848  std::function<Value(PatternRewriter &, Location, Value, int, int)>;
1849 
1850 /// Rewrite the i4 -> i8 extension into a sequence of shuffles and
1851 /// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1852 static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
1853  Value srcValue, const ExtractNBitsFn &extFn) {
1854  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
1855  assert(srcVecType.getElementType().isSignlessInteger(4) &&
1856  "Expected i4 type");
1857 
1858  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1859  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1860 
1861  // 2. Extend i4 elements to i8 elements. The low i4 elements of each
1862  // byte are placed in one vector and the high i4 elements in another vector.
1863  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1864  Value high = extFn(rewriter, loc, i8Vector, 4, 4);
1865 
1866  // 3. Interleave low and high i8 elements.
1867  return vector::InterleaveOp::create(rewriter, loc, low, high);
1868 }
1869 
1870 /// Rewrite the i2 -> i8 extension into a sequence of shuffles and
1871 /// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1872 static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
1873  Value srcValue, const ExtractNBitsFn &extFn) {
1874  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
1875  assert(srcVecType.getElementType().isSignlessInteger(2) &&
1876  "Expected i2 type");
1877 
1878  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
1879  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1880 
1881  // 2. Extract each i2 element
1882  // Position 0 (bits 0-1)
1883  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1884  // Position 1 (bits 2-3)
1885  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1886  // Position 2 (bits 4-5)
1887  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1888  // Position 3 (bits 6-7)
1889  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
1890 
1891  // 3. Interleave all 4 elements by first interleaving
1892  // even elements and then odd
1893  // vec0 = [0,0,0,0],...
1894  // vec1 = [1,1,1,1],...
1895  // vec2 = [2,2,2,2],...
1896  // vec3 = [3,3,3,3],...
1897  // 02 = [0,2,0,2,0,2,0,2],...
1898  // 13 = [1,3,1,3,1,3,1,3],...
1899  // 0213 = [0,1,2,3,...],...
1900  Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
1901  Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
1902  return vector::InterleaveOp::create(rewriter, loc, interleave02,
1903  interleave13);
1904 }
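// Illustrative sketch (not from the upstream file) for a hypothetical
// vector<8xi2> input, with <extract ...> standing for whichever extFn variant
// (sign- or zero-extending) was passed in:
//   %b   = vector.bitcast %in : vector<8xi2> to vector<2xi8>
//   %v0  = <extract bits [0, 2) of each byte of %b>  : vector<2xi8>
//   %v1  = <extract bits [2, 4) of each byte of %b>  : vector<2xi8>
//   %v2  = <extract bits [4, 6) of each byte of %b>  : vector<2xi8>
//   %v3  = <extract bits [6, 8) of each byte of %b>  : vector<2xi8>
//   %i02 = vector.interleave %v0, %v2 : vector<2xi8> -> vector<4xi8>
//   %i13 = vector.interleave %v1, %v3 : vector<2xi8> -> vector<4xi8>
//   %res = vector.interleave %i02, %i13 : vector<4xi8> -> vector<8xi8>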
1905 
1906 /// Rewrite the i8 -> i4 truncation into a deinterleave and a series of bitwise
1907 /// ops to avoid leaving LLVM to scramble with peephole optimizations.
1908 static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
1909  Value srcValue) {
1910  VectorType srcVecType = cast<VectorType>(srcValue.getType());
1911  assert(srcVecType.getElementType().isSignlessInteger(8) &&
1912  "Expected i8 type");
1913 
1914  // 1. De-interleave low and high i8 elements.
1915  auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);
1916 
1917  // 2. Zero out the upper side of each low i8 element.
1918  constexpr int8_t i8LowBitMask = 0x0F;
1919  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
1920  Value zeroOutMask = arith::ConstantOp::create(
1921  rewriter, loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
1922  Value zeroOutLow = arith::AndIOp::create(
1923  rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);
1924 
1925  // 3. Move high i4 values to upper side of the byte.
1926  constexpr int8_t bitsToShift = 4;
1927  auto shiftValues = arith::ConstantOp::create(
1928  rewriter, loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
1929  Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),
1930  shiftValues);
1931 
1932  // 4. Merge high and low i4 values.
1933  auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);
1934 
1935  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
1936  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
1937  return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
1938 }
1939 
1940 namespace {
1941 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
1942 /// advantage of high-level information to avoid leaving LLVM to scramble with
1943 /// peephole optimizations.
1944 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
1945  using OpRewritePattern::OpRewritePattern;
1946 
1947  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1948  PatternRewriter &rewriter) const override {
1949  // The source must be a trunc op.
1950  auto truncOp =
1951  bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
1952  if (!truncOp)
1953  return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
1954 
1955  // Set up the BitCastRewriter and verify the precondition.
1956  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
1957  VectorType targetVectorType = bitCastOp.getResultVectorType();
1958  BitCastRewriter bcr(sourceVectorType, targetVectorType);
1959  if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
1960  return failure();
1961 
1962  // Perform the rewrite.
1963  Value truncValue = truncOp.getIn();
1964  auto shuffledElementType =
1965  cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
1966  Value runningResult;
1967  for (const BitCastRewriter::Metadata &metadata :
1968  bcr.precomputeMetadata(shuffledElementType)) {
1969  runningResult = bcr.genericRewriteStep(
1970  rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
1971  }
1972 
1973  // Finalize the rewrite.
1974  bool narrowing = targetVectorType.getElementTypeBitWidth() <=
1975  shuffledElementType.getIntOrFloatBitWidth();
1976  if (narrowing) {
1977  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1978  rewriter.replaceOp(bitCastOp, runningResult);
1979  } else {
1980  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
1981  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1982  }
1983  } else {
1984  if (runningResult.getType() == bitCastOp.getResultVectorType()) {
1985  rewriter.replaceOp(bitCastOp, runningResult);
1986  } else {
1987  rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1988  bitCastOp, bitCastOp.getResultVectorType(), runningResult);
1989  }
1990  }
1991 
1992  return success();
1993  }
1994 };
1995 } // namespace
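// Illustrative sketch (not from the upstream file) of IR matched by
// RewriteBitCastOfTruncI:
//   %t = arith.trunci %a : vector<8xi32> to vector<8xi4>
//   %b = vector.bitcast %t : vector<8xi4> to vector<4xi8>
// Each i8 result element takes its bits from two i4 elements, so two
// shuffle/and/shrui/shli/or steps are generated on %a (the i32 vector), and
// the vector<4xi32> running result is finally truncated to vector<4xi8>.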
1996 
1997 //===----------------------------------------------------------------------===//
1998 // RewriteExtOfBitCast
1999 //===----------------------------------------------------------------------===//
2000 
2001 namespace {
2002 /// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
2003 /// take advantage of high-level information to avoid leaving LLVM to scramble
2004 /// with peephole optimizations.
2005 template <typename ExtOpType>
2006 struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
2007  using OpRewritePattern<ExtOpType>::OpRewritePattern;
2008 
2009  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
2010  : OpRewritePattern<ExtOpType>(context, benefit) {}
2011 
2012  LogicalResult matchAndRewrite(ExtOpType extOp,
2013  PatternRewriter &rewriter) const override {
2014  // The source must be a bitcast op.
2015  auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
2016  if (!bitCastOp)
2017  return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
2018 
2019  // Set up the BitCastRewriter and verify the precondition.
2020  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
2021  VectorType targetVectorType = bitCastOp.getResultVectorType();
2022  BitCastRewriter bcr(sourceVectorType, targetVectorType);
2023  if (failed(bcr.commonPrecondition(
2024  rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
2025  return failure();
2026 
2027  // Perform the rewrite.
2028  Value runningResult;
2029  Value sourceValue = bitCastOp.getSource();
2030  auto shuffledElementType =
2031  cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
2032  for (const BitCastRewriter::Metadata &metadata :
2033  bcr.precomputeMetadata(shuffledElementType)) {
2034  runningResult = bcr.genericRewriteStep(
2035  rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
2036  }
2037 
2038  // Finalize the rewrite.
2039  bool narrowing =
2040  cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
2041  shuffledElementType.getIntOrFloatBitWidth();
2042  if (narrowing) {
2043  rewriter.replaceOpWithNewOp<arith::TruncIOp>(
2044  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2045  } else {
2046  rewriter.replaceOpWithNewOp<ExtOpType>(
2047  extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
2048  }
2049 
2050  return success();
2051  }
2052 };
2053 
2054 /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
2055 /// bitwise ops that take advantage of high-level information to avoid leaving
2056 /// LLVM to scramble with peephole optimizations. Templated to choose between
2057 /// signed and unsigned conversions.
2058 ///
2059 /// EXAMPLE 1 (signed):
2060 /// arith.extsi %in : vector<8xi4> to vector<8xi32>
2061 /// is rewritten as:
2062 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2063 /// %1 = arith.shli %0, 4 : vector<4xi8>
2064 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
2065 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
2066 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2067 /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
2068 ///
2069 /// EXAMPLE 2 (fp):
2070 /// arith.sitofp %in : vector<8xi4> to vector<8xf32>
2071 /// is rewritten as:
2072 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2073 /// %1 = arith.shli %0, 4 : vector<4xi8>
2074 /// %2 = arith.shrsi %1, 4 : vector<4xi8>
2075 /// %3 = arith.shrsi %0, 4 : vector<4xi8>
2076 /// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
2077 /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
2078 ///
2079 /// EXAMPLE 3 (unsigned):
2080 /// arith.extui %in : vector<8xi4> to vector<8xi32>
2081 /// is rewritten as:
2082 /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
2083 /// %1 = arith.andi %0, 15 : vector<4xi8>
2084 /// %2 = arith.shrui %0, 4 : vector<4xi8>
2085 /// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
2086 /// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
2087 ///
2088 template <typename ConversionOpType, bool isSigned>
2089 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
2090  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
2091 
2092  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
2093  PatternRewriter &rewriter) const override {
2094  // Verify the preconditions.
2095  Value srcValue = conversionOp.getIn();
2096  VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
2097  VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
2098 
2099  if (failed(
2100  commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
2101  return failure();
2102 
2103  // Check general alignment preconditions.
2104  if (failed(alignedConversionPrecondition(
2105  rewriter, srcVecType,
2106  /*containerTy=*/rewriter.getI8Type(), conversionOp)))
2107  return failure();
2108 
2109  // Perform the rewrite.
2110  Location loc = conversionOp.getLoc();
2111  const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
2112  : extractNBitsPerByteAndExtendToI8;
2113  Value subByteExt;
2114  switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
2115  case 2:
2116  subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
2117  break;
2118  case 4:
2119  subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
2120  break;
2121  default:
2122  return failure();
2123  }
2124 
2125  // Finalize the rewrite.
2126  rewriter.replaceOpWithNewOp<ConversionOpType>(
2127  conversionOp, conversionOp.getType(), subByteExt);
2128  return success();
2129  }
2130 };
2131 
2132 /// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
2133 /// bitwise ops that take advantage of high-level information to avoid leaving
2134 /// LLVM to scramble with peephole optimizations.
2135 ///
2136 /// For example:
2137 /// arith.trunci %in : vector<8xi32> to vector<8xi4>
2138 ///
2139 /// is rewritten as:
2140 ///
2141 /// %cst = arith.constant dense<15> : vector<4xi8>
2142 /// %cst_0 = arith.constant dense<4> : vector<4xi8>
2143 /// %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
2144 /// %2 = arith.andi %0, %cst : vector<4xi8>
2145 /// %3 = arith.shli %1, %cst_0 : vector<4xi8>
2146 /// %4 = arith.ori %2, %3 : vector<4xi8>
2147 /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
2148 ///
2149 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
2150  using OpRewritePattern::OpRewritePattern;
2151 
2152  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
2153  PatternRewriter &rewriter) const override {
2154  // Verify the preconditions.
2155  Value srcValue = truncOp.getIn();
2156  auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
2157  auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
2158  if (!srcVecType || !dstVecType)
2159  return failure();
2160 
2161  if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
2162  return failure();
2163 
2164  // TODO: Add support for truncating to i2.
2165  if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
2166  return failure();
2167 
2168  // Check general alignment preconditions. We invert the src/dst type order
2169  // to reuse the existing precondition logic.
2170  if (failed(alignedConversionPrecondition(
2171  rewriter, dstVecType,
2172  /*containerTy=*/rewriter.getI8Type(), truncOp)))
2173  return failure();
2174 
2175  // Create a new iX -> i8 truncation op.
2176  Location loc = truncOp.getLoc();
2177  auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
2178  Value i8TruncVal =
2179  arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);
2180 
2181  // Rewrite the i8 -> i4 truncation part.
2182  Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
2183 
2184  // Finalize the rewrite.
2185  rewriter.replaceOp(truncOp, subByteTrunc);
2186  return success();
2187  }
2188 };
2189 
2190 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
2191 /// perform the transpose on wider (byte) element types.
2192 ///
2193 /// EXAMPLE:
2194 /// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
2195 ///
2196 /// is rewritten as:
2197 ///
2198 /// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
2199 /// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
2200 /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
2201 ///
2202 struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
2203  using OpRewritePattern::OpRewritePattern;
2204 
2205  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
2206  : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
2207 
2208  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
2209  PatternRewriter &rewriter) const override {
2210  // Precondition: sub-byte integer transpose.
2211  constexpr unsigned minNativeBitwidth = 8;
2212  VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
2213  if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
2214  srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
2215  return rewriter.notifyMatchFailure(transposeOp,
2216  "not a sub-byte transpose");
2217  }
2218 
2219  // Perform the rewrite.
2220  Location loc = transposeOp.getLoc();
2221  // Signed/unsigned interpretation shouldn't matter here as we are just
2222  // transposing the elements and truncating them back to the original size.
2223  // TODO: Use unsigned extension (more efficient) when emulation or backend
2224  // support is available.
2225  auto srcNativeVecType = srcSubByteVecType.cloneWith(
2226  std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
2227  Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
2228  transposeOp.getVector());
2229  Value newTranspose = vector::TransposeOp::create(
2230  rewriter, loc, extOp, transposeOp.getPermutation());
2231  VectorType dstSubByteVecType = transposeOp.getResultVectorType();
2232  rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
2233  newTranspose);
2234  return success();
2235  }
2236 };
2237 
2238 } // namespace
2239 
2240 //===----------------------------------------------------------------------===//
2241 // Public Interface Definition
2242 //===----------------------------------------------------------------------===//
2243 
2244 // The emulated type is inferred from the converted memref type.
2245 void vector::populateVectorNarrowTypeEmulationPatterns(
2246  const arith::NarrowTypeEmulationConverter &typeConverter,
2247  RewritePatternSet &patterns, bool disableAtomicRMW) {
2248  // Populate `vector.*` conversion patterns.
2249  // TODO: #119553 support atomicity
2250  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
2251  ConvertVectorMaskedStore, ConvertVectorTransferRead>(
2252  typeConverter, patterns.getContext());
2253 
2254  // Populate `vector.*` store conversion patterns. The caller can choose to
2255  // avoid emitting atomic operations and reduce stores to a plain
2256  // read-modify-write sequence when it is known there is no thread contention.
2257  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
2258 }
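// Illustrative usage sketch (hypothetical pass body, not part of this file):
// these emulation patterns are meant to be driven through the dialect
// conversion framework together with a NarrowTypeEmulationConverter.
//
//   arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
//   ConversionTarget target(*ctx);
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();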
2259 
2260 void vector::populateVectorNarrowTypeRewritePatterns(
2261  RewritePatternSet &patterns, PatternBenefit benefit) {
2262  // TODO: Document what the emulated type is.
2263  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
2264  RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
2265  benefit);
2266 
2267  // Patterns for aligned cases. We set higher priority as they are expected to
2268  // generate better performance for aligned cases.
2269  // The container type is always i8.
2270  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
2271  RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
2272  RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
2273  benefit.getBenefit() + 1);
2274  // The container type is always i8.
2275  patterns
2276  .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
2277  RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
2278  patterns.getContext(), benefit.getBenefit() + 1);
2279 }
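// Illustrative usage sketch (hypothetical, not part of this file): unlike the
// emulation patterns above, these are plain rewrite patterns and would
// normally be applied with the greedy pattern driver.
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeRewritePatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));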
2280 
2281 // The container type is always i8.
2282 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
2283  RewritePatternSet &patterns, PatternBenefit benefit) {
2284  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
2285 }
2286 
2287 void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
2288  arith::NarrowTypeEmulationConverter &typeConverter,
2289  RewritePatternSet &patterns) {
2290  memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
2291  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
2292 }