MLIR  14.0.0git
VectorOps.cpp
//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/VectorOps.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/bit.h"
#include <numeric>

#include "mlir/Dialect/Vector/VectorOpsDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/VectorOpsEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a 1-D mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
static MaskFormat get1DMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayAttr masks = m.mask_dim_sizes();
    assert(masks.size() == 1);
    int64_t i = masks[0].cast<IntegerAttr>().getInt();
    int64_t u = m.getType().getDimSize(0);
    if (i >= u)
      return MaskFormat::AllTrue;
    if (i <= 0)
      return MaskFormat::AllFalse;
  }
  return MaskFormat::Unknown;
}
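
// For example (illustrative), the ConstantMaskOp path above classifies:
//   vector.constant_mask [8] : vector<8xi1>   -> MaskFormat::AllTrue
//   vector.constant_mask [0] : vector<8xi1>   -> MaskFormat::AllFalse
//   vector.constant_mask [3] : vector<8xi1>   -> MaskFormat::Unknown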

// Helper for verifying combining kinds in contractions and reductions.
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINF:
  case CombiningKind::MAXF:
    return elementType.isa<FloatType>();
  }
  return false;
}

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}
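
// For example (illustrative): a default-layout memref<4x8xf32> has strides
// [8, 1] and passes this check, while a layout such as
// memref<4x8xf32, affine_map<(i, j) -> (j * 4 + i)>> has last stride 4 and
// is rejected.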

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir

CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
                                         MLIRContext *context) {
  return Base::get(context, static_cast<uint64_t>(kind));
}

CombiningKind CombiningKindAttr::getKind() const {
  return static_cast<CombiningKind>(getImpl()->value);
}

static constexpr const CombiningKind combiningKindsList[] = {
    // clang-format off
    CombiningKind::ADD,
    CombiningKind::MUL,
    CombiningKind::MINUI,
    CombiningKind::MINSI,
    CombiningKind::MINF,
    CombiningKind::MAXUI,
    CombiningKind::MAXSI,
    CombiningKind::MAXF,
    CombiningKind::AND,
    CombiningKind::OR,
    CombiningKind::XOR,
    // clang-format on
};

void CombiningKindAttr::print(AsmPrinter &printer) const {
  printer << "<";
  auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
    return bitEnumContains(this->getKind(), kind);
  });
  llvm::interleaveComma(kinds, printer,
                        [&](auto kind) { printer << stringifyEnum(kind); });
  printer << ">";
}
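
// For example (illustrative): an attribute holding CombiningKind::ADD prints
// its body as `<add>`; together with the "kind" mnemonic emitted by
// VectorDialect::printAttribute below, the full form is `#vector.kind<add>`.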

Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};

  StringRef elemName;
  if (failed(parser.parseKeyword(&elemName)))
    return {};

  auto kind = symbolizeCombiningKind(elemName);
  if (!kind) {
    parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
        << elemName;
    return {};
  }

  if (failed(parser.parseGreater()))
    return {};

  return CombiningKindAttr::get(kind.getValue(), parser.getContext());
}

Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
                                        Type type) const {
  StringRef attrKind;
  if (parser.parseKeyword(&attrKind))
    return {};

  if (attrKind == "kind")
    return CombiningKindAttr::parse(parser, {});

  parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
  return {};
}

void VectorDialect::printAttribute(Attribute attr,
                                   DialectAsmPrinter &os) const {
  if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
    os << "kind";
    ck.print(os);
    return;
  }
  llvm_unreachable("Unknown attribute type");
}

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

void VectorDialect::initialize() {
  addAttributes<CombiningKindAttr>();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
      >();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<arith::ConstantOp>(loc, type, value);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  result.addOperands(source);
  auto sourceVectorType = source.getType().cast<VectorType>();
  auto targetType = MultiDimReductionOp::inferDestType(
      sourceVectorType.getShape(), reductionMask,
      sourceVectorType.getElementType());
  result.addTypes(targetType);

  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  result.addAttribute(getReductionDimsAttrName(),
                      builder.getI64ArrayAttr(reductionDims));
  result.addAttribute(getKindAttrName(),
                      CombiningKindAttr::get(kind, builder.getContext()));
}

static LogicalResult verify(MultiDimReductionOp op) {
  auto reductionMask = op.getReductionMask();
  auto targetType = MultiDimReductionOp::inferDestType(
      op.getSourceVectorType().getShape(), reductionMask,
      op.getSourceVectorType().getElementType());
  // TODO: update to support 0-d vectors when available.
  if (targetType != op.getDestType())
    return op.emitError("invalid output vector type: ")
           << op.getDestType() << " (expected: " << targetType << ")";
  return success();
}

OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return source();
  return {};
}
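
// Illustrative fold (schematic syntax): a multi_reduction whose source is
// 1-D and whose reduction_dims list is empty reduces nothing at all, e.g.
//   %r = vector.multi_reduction #vector.kind<add>, %v [] : vector<4xf32>
// folds directly to %v.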

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReductionOp op) {
  // Verify for 1-D vector.
  int64_t rank = op.getVectorType().getRank();
  if (rank != 1)
    return op.emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  StringRef strKind = op.kind();
  auto maybeKind = symbolizeCombiningKind(strKind);
  if (!maybeKind)
    return op.emitOpError("unknown reduction kind: ") << strKind;

  Type eltType = op.dest().getType();
  if (!isSupportedCombiningKind(*maybeKind, eltType))
    return op.emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << op.kind() << "'";

  // Verify optional accumulator.
  if (!op.acc().empty()) {
    if (strKind != "add" && strKind != "mul")
      return op.emitOpError("no accumulator for reduction kind: ") << strKind;
    if (!eltType.isa<FloatType>())
      return op.emitOpError("no accumulator for type: ") << eltType;
  }

  return success();
}

static ParseResult parseReductionOp(OpAsmParser &parser,
                                    OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
  Type redType;
  Type resType;
  Attribute attr;
  if (parser.parseAttribute(attr, "kind", result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (!operandsInfo.empty() &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  if (operandsInfo.empty() || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}

static void print(OpAsmPrinter &p, ReductionOp op) {
  p << " \"" << op.kind() << "\", " << op.vector();
  if (!op.acc().empty())
    p << ", " << op.acc();
  p << " : " << op.vector().getType() << " into " << op.dest().getType();
}
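
// For example, this custom form prints/parses reductions such as:
//   %0 = vector.reduction "add", %v : vector<16xf32> into f32
//   %1 = vector.reduction "mul", %v, %acc : vector<16xf32> into f32
// where the optional second operand is the accumulator, permitted only for
// "add"/"mul" on floats (see the verifier above).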

Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  Type scalarType = vector.getType().cast<ShapedType>().getElementType();
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("add"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("mul"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::minf:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("minf"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("minsi"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("minui"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::maxf:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("maxf"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("maxsi"),
                                               vector, ValueRange{});
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                               builder.getStringAttr("maxui"),
                                               vector, ValueRange{});
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<StringRef> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(),
                      builder.getAffineMapArrayAttr(
                          AffineMap::inferFromExprList(indexingExprs)));
  result.addAttribute(getIteratorTypesAttrName(),
                      builder.getStrArrayAttr(iteratorTypes));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
  result.addAttribute(ContractionOp::getKindAttrName(),
                      CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                             builder.getContext()));
}

static ParseResult parseContractionOp(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::OperandType lhsInfo;
  OpAsmParser::OperandType rhsInfo;
  OpAsmParser::OperandType accInfo;
  SmallVector<OpAsmParser::OperandType, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
      parser.parseOperand(lhsInfo) || parser.parseComma() ||
      parser.parseOperand(rhsInfo) || parser.parseComma() ||
      parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());
  if (!result.attributes.get(ContractionOp::getKindAttrName())) {
    result.addAttribute(ContractionOp::getKindAttrName(),
                        CombiningKindAttr::get(ContractionOp::getDefaultKind(),
                                               result.getContext()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = types[0].cast<VectorType>();
  auto rhsType = types[1].cast<VectorType>();
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<Type, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

static void print(OpAsmPrinter &p, ContractionOp op) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = op.getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : op->getAttrs())
    if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);

  auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);
  p << " " << dictAttr << " " << op.lhs() << ", ";
  p << op.rhs() << ", " << op.acc();
  if (op.masks().size() == 2)
    p << ", " << op.masks();

  p.printOptionalAttrDict(op->getAttrs(), attrNames);
  p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
    << op.getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMaps()[0];
    AffineMap rhsMap = op.getIndexingMaps()[1];
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
           "expected extent along all dimensions.");

    AffineMap resMap = op.getIndexingMaps()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(
               expectedMap.getResults(),
               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
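
// For example (illustrative): for a matmul-shaped contraction with indexing
// maps
//   (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)
// and operand types vector<2x4xf32> and vector<4x3xf32>, the extents resolve
// to (m, n, k) = (2, 3, 4); composing the result map with the extents map
// yields the expected accumulator/result type vector<2x3xf32>.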

static LogicalResult verify(ContractionOp op) {
  auto lhsType = op.getLhsType();
  auto rhsType = op.getRhsType();
  auto accType = op.getAccType();
  auto resType = op.getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (op.indexing_maps().size() != 3)
    return op.emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = op.iterator_types().getValue().size();
  for (const auto &it : llvm::enumerate(op.indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    if (map.getNumSymbols() != 0)
      return op.emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and indices.
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return op.emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = op.getContractingDimMap();
  auto batchDimMap = op.getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return op.emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return op.emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return op.emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = op.getLHSVectorMaskType();
  auto rhsMaskType = op.getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return op.emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return op.emitOpError("invalid vector mask rank");
  }

  // Verify supported combining kind.
  auto vectorType = resType.dyn_cast<VectorType>();
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(op.kind(), elementType))
    return op.emitOpError("unsupported contraction type");

  return success();
}

ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[3] = {getIndexingMapsAttrName(),
                                         getIteratorTypesAttrName(),
                                         ContractionOp::getKindAttrName()};
  return llvm::makeArrayRef(names);
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          StringRef targetIteratorTypeName, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName != targetIteratorTypeName)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}
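
// For the matmul example above, with iterator_types
// ["parallel", "parallel", "reduction"], the contracting dimension map is
// {(1, 0)}: the reduction iterator k appears as result 1 of the lhs map and
// result 0 of the rhs map. The parallel iterators m and n each appear in only
// one of the two operand maps, so the batch dimension map stays empty.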

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = getResultType().dyn_cast<VectorType>();
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  SmallVector<int64_t, 2> iterationShape;
  for (const auto &it : llvm::enumerate(iterator_types())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName == getReductionIteratorTypeName()) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = indexing_maps().getValue().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = map.getResult(i).cast<AffineDimExpr>();
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getReductionIteratorTypeName(), getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getParallelIteratorTypeName(), getContext());
}

SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
  return llvm::to_vector<4>(
      llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
        return mapAttr.cast<AffineMapAttr>().getValue();
      }));
}

Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which represents a pattern such as:
///
/// ```mlir
///   %c0 = vector.constant 0: ...
///   %c = vector.contract %a, %b, %c0: ...
///   %e = add %c, %d: ...
/// ```
///
/// by:
///
/// ```mlir
///   %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.acc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.acc().getType())) {
          BlockAndValueMapping bvm;
          bvm.map(contractionOp.acc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<arith::AddIOp>,
              CanonicalizeContractAdd<arith::AddFOp>>(context);
}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source) {
  result.addOperands({source});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source, Value position) {
  result.addOperands({source, position});
  result.addTypes(source.getType().cast<VectorType>().getElementType());
}

static LogicalResult verify(vector::ExtractElementOp op) {
  VectorType vectorType = op.getVectorType();
  if (vectorType.getRank() == 0) {
    if (op.position())
      return op.emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return op.emitOpError("unexpected >1 vector rank");
  if (!op.position())
    return op.emitOpError("expected position for 1-D vector");
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static Type inferExtractOpResultType(VectorType vectorType,
                                     ArrayAttr position) {
  if (static_cast<int64_t>(position.size()) == vectorType.getRank())
    return vectorType.getElementType();
  return VectorType::get(vectorType.getShape().drop_front(position.size()),
                         vectorType.getElementType());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  result.addOperands(source);
  auto positionAttr = getVectorSubscriptAttr(builder, position);
  result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
                                           positionAttr));
  result.addAttribute(getPositionAttrName(), positionAttr);
}

// Convenience builder which assumes the values are constant indices.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ValueRange position) {
  SmallVector<int64_t, 4> positionConstants =
      llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
        return pos.getDefiningOp<arith::ConstantIndexOp>().value();
      }));
  build(builder, result, source, positionConstants);
}

static void print(OpAsmPrinter &p, vector::ExtractOp op) {
  p << " " << op.vector() << op.position();
  p.printOptionalAttrDict(op->getAttrs(), {"position"});
  p << " : " << op.vector().getType();
}

static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc attributeLoc, typeLoc;
  NamedAttrList attrs;
  OpAsmParser::OperandType vector;
  Type type;
  Attribute attr;
  if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(attr, "position", attrs) ||
      parser.parseOptionalAttrDict(attrs) ||
      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
    return failure();

  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typeLoc, "expected vector type");

  auto positionAttr = attr.dyn_cast<ArrayAttr>();
  if (!positionAttr ||
      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
    return parser.emitError(
        attributeLoc,
        "expected position attribute of rank smaller than vector rank");

  Type resType = inferExtractOpResultType(vectorType, positionAttr);
  result.attributes = attrs;
  return failure(parser.resolveOperand(vector, type, result.operands) ||
                 parser.addTypeToList(resType, result.types));
}

static LogicalResult verify(vector::ExtractOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than vector rank");
  for (const auto &en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= op.getVectorType().getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.vector().getDefiningOp<ExtractOp>())
    return failure();

  SmallVector<int64_t, 4> globalPosition;
  ExtractOp currentOp = extractOp;
  auto extrPos = extractVector<int64_t>(currentOp.position());
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    auto extrPos = extractVector<int64_t>(currentOp.position());
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(currentOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(globalPosition));
  return success();
}
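
// For example (illustrative):
//   %0 = vector.extract %v[1] : vector<2x3x4xf32>
//   %1 = vector.extract %0[2] : vector<3x4xf32>
// folds the second op in place to the equivalent of
//   %1 = vector.extract %v[1, 2] : vector<2x3x4xf32>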

namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result vector.
/// As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over producing insert and transpose ops until we find a fold.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`. Under insert/extract semantics, this is the same as `a`
  /// is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`. Under insert/extract semantics, this is the same as equality
  /// of all entries of `a` that are >=0 with the corresponding entries of `b`.
  /// Comparison is on the common prefix (i.e. zip).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto it : llvm::zip(a, b)) {
      // Skip positions where either side holds a (negative) sentinel.
      if (std::get<0>(it) < 0 || std::get<1>(it) < 0)
        continue;
      if (std::get<0>(it) != std::get<1>(it))
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in the
  /// result vector.
  bool canFold() {
    return (sentinels ==
            makeArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to get the next defining op of interest.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  };

  // Case 1. If we hit a transpose, just compose the map and iterate.
  // Invariant: insert + transpose do not change rank, we can always compose.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches extractPosition exactly, early return.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of extractPosition, extract a
  /// portion of the source of the insert.
  /// Example:
  /// ```
  /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
  /// // extractPosition == [1, 2, 3]
  /// %ext = vector.extract %ins[1, 0]: vector<3x4x5>
  /// // can fold to vector.extract %source[0, 3]
  /// %ext = vector.extract %source[3]: vector<5x6>
  /// ```
  /// To traverse through %source, we need to set the leading dims to 0 and
  /// drop the extra leading dims.
  /// This method updates the internal state.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place to extract(source, extractPosition) and return the
  /// folded result. Return null if folding is not possible (e.g. due to an
  /// internal transposition in the result).
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values that encode the internal permutation status of the result.
  /// They are set to (-1, ... , -k) at the beginning and appended to
  /// `extractPosition`.
  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
  /// ensure that there is no internal transposition.
  /// Internal transposition cannot be accounted for with a folding pattern.
  // TODO: We could relax the internal transposition with an extra transposition
  // operation in a future canonicalizer.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
} // namespace

ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
      extractedRank(extractOp.position().size()) {
  assert(vectorRank >= extractedRank && "extracted pos overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition = extractVector<int64_t>(extractOp.position());
  llvm::append_range(extractPosition, sentinels);
}

// Case 1. If we hit a transpose, just compose the map and iterate.
// Invariant: insert + transpose do not change rank, we can always compose.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  if (!nextTransposeOp)
    return failure();
  auto permutation = extractVector<unsigned>(nextTransposeOp.transp());
  AffineMap m = inversePermutation(
      AffineMap::getPermutationMap(permutation, extractOp.getContext()));
  extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition));
  return success();
}

// Case 2: the insert position matches extractPosition exactly, early return.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
  if (makeArrayRef(insertedPos) !=
      llvm::makeArrayRef(extractPosition).take_front(extractedRank))
    return failure();
  // Case 2.a. early-exit fold.
  res = nextInsertOp.source();
  // Case 2.b. if internal transposition is present, canFold will be false.
  return success();
}

/// Case 3: if inserted position is a prefix of extractPosition,
/// extract a portion of the source of the insertion.
/// This method updates the internal state.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
  if (!isContainedWithin(insertedPos, extractPosition))
    return failure();
  // Set leading dims to zero.
  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
  // Drop extra leading dims.
  extractPosition.erase(extractPosition.begin(),
                        extractPosition.begin() + insertedPos.size());
  extractedRank = extractPosition.size() - sentinels.size();
  // Case 3.a. early-exit fold (break and delegate to post-while path).
  res = nextInsertOp.source();
  // Case 3.b. if internal transposition is present, canFold will be false.
  return success();
}

/// Try to fold in place to extract(source, extractPosition) and return the
/// folded result. Return null if folding is not possible (e.g. due to an
/// internal transposition in the result).
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  // If we can't fold (either internal transposition, or nothing to fold), bail.
  bool nothingToFold = (source == extractOp.vector());
  if (nothingToFold || !canFold())
    return Value();
  // Otherwise, fold by updating the op inplace and return its result.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(
      extractOp.positionAttrName(),
      b.getI64ArrayAttr(
          makeArrayRef(extractPosition).take_front(extractedRank)));
  extractOp.vectorMutable().assign(source);
  return extractOp.getResult();
}

/// Iterate over producing insert and transpose ops until we find a fold.
Value ExtractFromInsertTransposeChainState::fold() {
  Value valueToExtractFrom = extractOp.vector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, just compose the map and iterate.
    // Invariant: insert + transpose do not change rank, we can always compose.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.vector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the positions match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: if the inserted position is a prefix of extractPosition, we can
    // just extract a portion of the source of the insert.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
    // values. This is a more difficult case and we bail.
    auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
    if (isContainedWithin(extractPosition, insertedPos) ||
        intersectsWhereNonNegative(extractPosition, insertedPos))
      return Value();

    // Case 5: No intersection, we forward the extract to insertOp.dest().
    valueToExtractFrom = nextInsertOp.dest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}

/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  Operation *defOp = extractOp.vector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();
  Value source = defOp->getOperand(0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
  };
  unsigned broadcastSrcRank = getRank(source.getType());
  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank < broadcastSrcRank) {
    auto extractPos = extractVector<int64_t>(extractOp.position());
    unsigned rankDiff = broadcastSrcRank - extractResultRank;
    extractPos.erase(
        extractPos.begin(),
        std::next(extractPos.begin(), extractPos.size() - rankDiff));
    extractOp.setOperand(source);
    // OpBuilder is only used as a helper to build an I64ArrayAttr.
    OpBuilder b(extractOp.getContext());
    extractOp->setAttr(ExtractOp::getPositionAttrName(),
                       b.getI64ArrayAttr(extractPos));
    return extractOp.getResult();
  }
  return Value();
}
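
// For example (illustrative): extracting the broadcast source type folds to
// the source itself,
//   %b = vector.broadcast %f : f32 to vector<4x8xf32>
//   %e = vector.extract %b[1, 3] : vector<4x8xf32>   // folds to %f
// and when the broadcast source is a lower-rank vector, the leading positions
// covering the broadcast-added dimensions are dropped instead.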

// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimension of the destination must match the lowest
      // dimension of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.source());
  return extractOp.getResult();
}
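
// For example (illustrative):
//   %sc = vector.shape_cast %src : vector<8xf32> to vector<2x4xf32>
//   %e  = vector.extract %sc[1, 2] : vector<2x4xf32>
// linearizes the position [1, 2] with strides [4, 1] to 6, delinearizes 6
// against the source shape, and rewrites the op to extract %src[6] directly.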

/// Fold an ExtractOp from ExtractStridedSliceOp.
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
  auto extractStridedSliceOp =
      extractOp.vector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();
  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted.
  auto sliceOffsets = extractVector<int64_t>(extractStridedSliceOp.offsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the
  // extractStridedSlice op.
  if (destinationRank >
      extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
    return Value();
  auto extractedPos = extractVector<int64_t>(extractOp.position());
  assert(extractedPos.size() >= sliceOffsets.size());
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.vectorMutable().assign(extractStridedSliceOp.vector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(extractedPos));
  return extractOp.getResult();
}
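
// For example (illustrative):
//   %s = vector.extract_strided_slice %v
//        {offsets = [2], sizes = [4], strides = [1]}
//        : vector<8x16xf32> to vector<4x16xf32>
//   %e = vector.extract %s[1] : vector<4x16xf32>
// folds to an extract straight from %v at offset 1 + 2 = 3:
//   %e = vector.extract %v[3] : vector<8x16xf32>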

/// Fold extract_op fed from a chain of insertStridedSlice ops.
static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
  int64_t destinationRank = op.getType().isa<VectorType>()
                                ? op.getType().cast<VectorType>().getRank()
                                : 0;
  auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.offsets());
    auto extractOffsets = extractVector<int64_t>(op.position());

    if (llvm::any_of(insertOp.strides(), [](Attribute attr) {
          return attr.cast<IntegerAttr>().getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the start of the extract offset is in the interval inserted.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted element chunk overlaps with the vector inserted.
    if (!disjoint) {
      // If any of the inner dimensions are only partially inserted we have a
      // partial overlap.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      op.vectorMutable().assign(insertOp.source());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op->setAttr(ExtractOp::getPositionAttrName(),
                  b.getI64ArrayAttr(offsetDiffs));
      return op.getResult();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep
    // looking in the insert chain.
    insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
}
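
// For example (illustrative):
//   %i = vector.insert_strided_slice %src, %dst
//        {offsets = [2, 0], strides = [1, 1]}
//        : vector<2x16xf32> into vector<8x16xf32>
//   %e = vector.extract %i[3] : vector<8x16xf32>
// Row 3 lies inside the inserted chunk (rows [2, 4)) and the inner dimension
// is fully inserted, so this folds to %e = vector.extract %src[1]. An extract
// of a row outside [2, 4) would instead be forwarded to %dst.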

OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
  if (position().empty())
    return vector();
  if (succeeded(foldExtractOpFromExtractChain(*this)))
    return getResult();
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;
  if (auto res = foldExtractFromBroadcast(*this))
    return res;
  if (auto res = foldExtractFromShapeCast(*this))
    return res;
  if (auto val = foldExtractFromExtractStrided(*this))
    return val;
  if (auto val = foldExtractStridedOpFromInsertChain(*this))
    return val;
  return OpFoldResult();
}

namespace {

// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern<ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    Operation *defOp = extractOp.vector().getDefiningOp();
    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
      return failure();
    Value source = defOp->getOperand(0);
    if (extractOp.getType() == source.getType())
      return failure();
    auto getRank = [](Type type) {
      return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
    };
    unsigned broadcastSrcRank = getRank(source.getType());
    unsigned extractResultRank = getRank(extractOp.getType());
    // We only consider the case where the rank of the source is smaller than
    // the rank of the extract dst. The other cases are handled in the folding
    // patterns.
    if (extractResultRank <= broadcastSrcRank)
      return failure();
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        extractOp, extractOp.getType(), source);
    return success();
  }
};
1433 
1434 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
1435 class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
1436 public:
1438 
1439  LogicalResult matchAndRewrite(ExtractOp extractOp,
1440  PatternRewriter &rewriter) const override {
1441  // Return if 'extractStridedSliceOp' operand is not defined by a
1442  // ConstantOp.
1443  auto constantOp = extractOp.vector().getDefiningOp<arith::ConstantOp>();
1444  if (!constantOp)
1445  return failure();
1446  auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
1447  if (!dense)
1448  return failure();
1449  Attribute newAttr = dense.getSplatValue<Attribute>();
1450  if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
1451  newAttr = DenseElementsAttr::get(vecDstType, newAttr);
1452  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1453  return success();
1454  }
1455 };
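// For example (illustrative values), extracting from a splat constant
//
//   %c = arith.constant dense<1.0> : vector<4x8xf32>
//   %e = vector.extract %c[3] : vector<4x8xf32>
//
// becomes %e = arith.constant dense<1.0> : vector<8xf32>.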
1456 
1457 } // namespace
1458 
1459 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1460  MLIRContext *context) {
1461  results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
1462 }
1463 
1464 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
1465  SmallVectorImpl<int64_t> &results) {
1466  for (auto attr : arrayAttr)
1467  results.push_back(attr.cast<IntegerAttr>().getInt());
1468 }
1469 
1470 //===----------------------------------------------------------------------===//
1471 // ExtractMapOp
1472 //===----------------------------------------------------------------------===//
1473 
1474 void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
1475  Value vector, ValueRange ids,
1476  ArrayRef<int64_t> multiplicity,
1477  AffineMap permutationMap) {
1478  assert(ids.size() == multiplicity.size() &&
1479  ids.size() == permutationMap.getNumResults());
1480  assert(permutationMap.isProjectedPermutation());
1481  VectorType type = vector.getType().cast<VectorType>();
1482  SmallVector<int64_t, 4> newShape(type.getShape().begin(),
1483  type.getShape().end());
1484  for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
1485  AffineExpr expr = permutationMap.getResult(i);
1486  auto dim = expr.cast<AffineDimExpr>();
1487  newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
1488  }
1489  VectorType resultType = VectorType::get(newShape, type.getElementType());
1490  ExtractMapOp::build(builder, result, resultType, vector, ids);
1491 }
1492 
1493 static LogicalResult verify(ExtractMapOp op) {
1494  if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
1495  return op.emitOpError(
1496  "expected source and destination vectors of same rank");
1497  unsigned numId = 0;
1498  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
1499  if (op.getSourceVectorType().getDimSize(i) %
1500  op.getResultType().getDimSize(i) !=
1501  0)
1502  return op.emitOpError("source vector dimensions must be a multiple of "
1503  "destination vector dimensions");
1504  if (op.getSourceVectorType().getDimSize(i) !=
1505  op.getResultType().getDimSize(i))
1506  numId++;
1507  }
1508  if (numId != op.ids().size())
1509  return op.emitOpError("expected number of ids to match the number of "
1510  "dimensions distributed");
1511  return success();
1512 }
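// A sketch of the distribution semantics being verified (hypothetical
// values): each id selects one chunk of a distributed dimension, e.g.
//
//   %ev = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
//
// distributes the single dimension with multiplicity 32 across 32 ids.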
1513 
1514 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
1515  auto insert = vector().getDefiningOp<vector::InsertMapOp>();
1516  if (insert == nullptr || getType() != insert.vector().getType() ||
1517  ids() != insert.ids())
1518  return {};
1519  return insert.vector();
1520 }
1521 
1522 void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
1523  assert(multiplicity.empty());
1524  for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
1525  if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
1526  multiplicity.push_back(getSourceVectorType().getDimSize(i) /
1527  getResultType().getDimSize(i));
1528  }
1529 }
1530 
1531 template <typename MapOp>
1532 static AffineMap calculateImplicitMap(MapOp op) {
1533  SmallVector<AffineExpr, 4> perm;
1534  // Check which dimensions have a multiplicity greater than 1 and associate
1535  // them to the IDs in order.
1536  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
1537  if (op.getSourceVectorType().getDimSize(i) !=
1538  op.getResultType().getDimSize(i))
1539  perm.push_back(getAffineDimExpr(i, op.getContext()));
1540  }
1541  auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
1542  op.getContext());
1543  return map;
1544 }
1545 
1546 AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }
1547 
1548 //===----------------------------------------------------------------------===//
1549 // FmaOp
1550 //===----------------------------------------------------------------------===//
1551 
1552 Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
1553  return llvm::to_vector<4>(getVectorType().getShape());
1554 }
1555 
1556 //===----------------------------------------------------------------------===//
1557 // BroadcastOp
1558 //===----------------------------------------------------------------------===//
1559 
1561 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
1562  std::pair<int, int> *mismatchingDims) {
1563  // Broadcast scalar to vector of the same element type.
1564  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
1565  getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
1566  return BroadcastableToResult::Success;
1567  // From now on, only vectors broadcast.
1568  VectorType srcVectorType = srcType.dyn_cast<VectorType>();
1569  if (!srcVectorType)
1570  return BroadcastableToResult::SourceTypeNotAVector;
1571 
1572  int64_t srcRank = srcVectorType.getRank();
1573  int64_t dstRank = dstVectorType.getRank();
1574  if (srcRank > dstRank)
1575  return BroadcastableToResult::SourceRankHigher;
1576  // Source has an exact match or singleton value for all trailing dimensions
1577  // (all leading dimensions are simply duplicated).
1578  int64_t lead = dstRank - srcRank;
1579  for (int64_t r = 0; r < srcRank; ++r) {
1580  int64_t srcDim = srcVectorType.getDimSize(r);
1581  int64_t dstDim = dstVectorType.getDimSize(lead + r);
1582  if (srcDim != 1 && srcDim != dstDim) {
1583  if (mismatchingDims) {
1584  mismatchingDims->first = srcDim;
1585  mismatchingDims->second = dstDim;
1586  }
1587  return BroadcastableToResult::DimensionMismatch;
1588  }
1589  }
1590 
1591  return BroadcastableToResult::Success;
1592 }
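// For example, vector<1x4xf32> is broadcastable to vector<3x4xf32> (leading
// dims are duplicated and unit dims are stretched), while vector<3xf32> to
// vector<4xf32> reports mismatching dims (3 vs. 4).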
1593 
1594 static LogicalResult verify(BroadcastOp op) {
1595  std::pair<int, int> mismatchingDims;
1596  BroadcastableToResult res = isBroadcastableTo(
1597  op.getSourceType(), op.getVectorType(), &mismatchingDims);
1598  if (res == BroadcastableToResult::Success)
1599  return success();
1600  if (res == BroadcastableToResult::SourceRankHigher)
1601  return op.emitOpError("source rank higher than destination rank");
1602  if (res == BroadcastableToResult::DimensionMismatch)
1603  return op.emitOpError("dimension mismatch (")
1604  << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
1605  if (res == BroadcastableToResult::SourceTypeNotAVector)
1606  return op.emitOpError("source type is not a vector");
1607  llvm_unreachable("unexpected vector.broadcast op error");
1608 }
1609 
1610 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
1611  if (getSourceType() == getVectorType())
1612  return source();
1613  if (!operands[0])
1614  return {};
1615  auto vectorType = getVectorType();
1616  if (operands[0].getType().isIntOrIndexOrFloat())
1617  return DenseElementsAttr::get(vectorType, operands[0]);
1618  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
1619  return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
1620  return {};
1621 }
1622 
1623 namespace {
1624 
1625 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
1626 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
1627  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
1628 
1629  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
1630  PatternRewriter &rewriter) const override {
1631  auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>();
1632  if (!srcBroadcast)
1633  return failure();
1634  rewriter.replaceOpWithNewOp<BroadcastOp>(
1635  broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source());
1636  return success();
1637  }
1638 };
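// For illustration (hypothetical values), the pattern rewrites
//
//   %b1 = vector.broadcast %s : f32 to vector<4xf32>
//   %b2 = vector.broadcast %b1 : vector<4xf32> to vector<2x4xf32>
//
// into the single %b2 = vector.broadcast %s : f32 to vector<2x4xf32>.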
1639 } // namespace
1640 
1641 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1642  MLIRContext *context) {
1643  // BroadcastToShapeCast is not a default canonicalization; it is opted into
1644  // by calling `populateCastAwayVectorLeadingOneDimPatterns`.
1645  results.add<BroadcastFolder>(context);
1646 }
1647 
1648 //===----------------------------------------------------------------------===//
1649 // ShuffleOp
1650 //===----------------------------------------------------------------------===//
1651 
1652 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
1653  Value v2, ArrayRef<int64_t> mask) {
1654  result.addOperands({v1, v2});
1655  auto maskAttr = getVectorSubscriptAttr(builder, mask);
1656  auto v1Type = v1.getType().cast<VectorType>();
1657  auto shape = llvm::to_vector<4>(v1Type.getShape());
1658  shape[0] = mask.size();
1659  result.addTypes(VectorType::get(shape, v1Type.getElementType()));
1660  result.addAttribute(getMaskAttrName(), maskAttr);
1661 }
1662 
1663 static void print(OpAsmPrinter &p, ShuffleOp op) {
1664  p << " " << op.v1() << ", " << op.v2() << " " << op.mask();
1665  p.printOptionalAttrDict(op->getAttrs(), {ShuffleOp::getMaskAttrName()});
1666  p << " : " << op.v1().getType() << ", " << op.v2().getType();
1667 }
1668 
1669 static LogicalResult verify(ShuffleOp op) {
1670  VectorType resultType = op.getVectorType();
1671  VectorType v1Type = op.getV1VectorType();
1672  VectorType v2Type = op.getV2VectorType();
1673  // Verify ranks.
1674  int64_t resRank = resultType.getRank();
1675  int64_t v1Rank = v1Type.getRank();
1676  int64_t v2Rank = v2Type.getRank();
1677  if (resRank != v1Rank || v1Rank != v2Rank)
1678  return op.emitOpError("rank mismatch");
1679  // Verify all but leading dimension sizes.
1680  for (int64_t r = 1; r < v1Rank; ++r) {
1681  int64_t resDim = resultType.getDimSize(r);
1682  int64_t v1Dim = v1Type.getDimSize(r);
1683  int64_t v2Dim = v2Type.getDimSize(r);
1684  if (resDim != v1Dim || v1Dim != v2Dim)
1685  return op.emitOpError("dimension mismatch");
1686  }
1687  // Verify mask length.
1688  auto maskAttr = op.mask().getValue();
1689  int64_t maskLength = maskAttr.size();
1690  if (maskLength != resultType.getDimSize(0))
1691  return op.emitOpError("mask length mismatch");
1692  // Verify all indices.
1693  int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
1694  for (const auto &en : llvm::enumerate(maskAttr)) {
1695  auto attr = en.value().dyn_cast<IntegerAttr>();
1696  if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
1697  return op.emitOpError("mask index #")
1698  << (en.index() + 1) << " out of range";
1699  }
1700  return success();
1701 }
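// A valid example (illustrative values): mask indices address the
// concatenation of both operands, so two vector<2xf32> operands admit
// indices in [0, 4):
//
//   %s = vector.shuffle %a, %b [0, 3] : vector<2xf32>, vector<2xf32>
//
// which yields a vector<2xf32> holding %a[0] and %b[1].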
1702 
1703 static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
1704  OpAsmParser::OperandType v1, v2;
1705  Attribute attr;
1706  VectorType v1Type, v2Type;
1707  if (parser.parseOperand(v1) || parser.parseComma() ||
1708  parser.parseOperand(v2) ||
1709  parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
1710  result.attributes) ||
1711  parser.parseOptionalAttrDict(result.attributes) ||
1712  parser.parseColonType(v1Type) || parser.parseComma() ||
1713  parser.parseType(v2Type) ||
1714  parser.resolveOperand(v1, v1Type, result.operands) ||
1715  parser.resolveOperand(v2, v2Type, result.operands))
1716  return failure();
1717  // Construct resulting type: leading dimension matches mask length,
1718  // all trailing dimensions match the operands.
1719  auto maskAttr = attr.dyn_cast<ArrayAttr>();
1720  if (!maskAttr)
1721  return parser.emitError(parser.getNameLoc(), "missing mask attribute");
1722  int64_t maskLength = maskAttr.size();
1723  if (maskLength <= 0)
1724  return parser.emitError(parser.getNameLoc(), "invalid mask length");
1725  int64_t v1Rank = v1Type.getRank();
1726  SmallVector<int64_t, 4> shape;
1727  shape.reserve(v1Rank);
1728  shape.push_back(maskLength);
1729  for (int64_t r = 1; r < v1Rank; ++r)
1730  shape.push_back(v1Type.getDimSize(r));
1731  VectorType resType = VectorType::get(shape, v1Type.getElementType());
1732  parser.addTypeToList(resType, result.types);
1733  return success();
1734 }
1735 
1736 //===----------------------------------------------------------------------===//
1737 // InsertElementOp
1738 //===----------------------------------------------------------------------===//
1739 
1740 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
1741  Value source, Value dest) {
1742  result.addOperands({source, dest});
1743  result.addTypes(dest.getType());
1744 }
1745 
1746 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
1747  Value source, Value dest, Value position) {
1748  result.addOperands({source, dest, position});
1749  result.addTypes(dest.getType());
1750 }
1751 
1752 static LogicalResult verify(InsertElementOp op) {
1753  auto dstVectorType = op.getDestVectorType();
1754  if (dstVectorType.getRank() == 0) {
1755  if (op.position())
1756  return op.emitOpError("expected position to be empty with 0-D vector");
1757  return success();
1758  }
1759  if (dstVectorType.getRank() != 1)
1760  return op.emitOpError("unexpected >1 vector rank");
1761  if (!op.position())
1762  return op.emitOpError("expected position for 1-D vector");
1763  return success();
1764 }
1765 
1766 //===----------------------------------------------------------------------===//
1767 // InsertOp
1768 //===----------------------------------------------------------------------===//
1769 
1770 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1771  Value dest, ArrayRef<int64_t> position) {
1772  result.addOperands({source, dest});
1773  auto positionAttr = getVectorSubscriptAttr(builder, position);
1774  result.addTypes(dest.getType());
1775  result.addAttribute(getPositionAttrName(), positionAttr);
1776 }
1777 
1778 // Convenience builder which assumes the values are constant indices.
1779 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1780  Value dest, ValueRange position) {
1781  SmallVector<int64_t, 4> positionConstants =
1782  llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
1783  return pos.getDefiningOp<arith::ConstantIndexOp>().value();
1784  }));
1785  build(builder, result, source, dest, positionConstants);
1786 }
1787 
1788 static LogicalResult verify(InsertOp op) {
1789  auto positionAttr = op.position().getValue();
1790  auto destVectorType = op.getDestVectorType();
1791  if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
1792  return op.emitOpError(
1793  "expected position attribute of rank smaller than dest vector rank");
1794  auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
1795  if (srcVectorType &&
1796  (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
1797  static_cast<unsigned>(destVectorType.getRank())))
1798  return op.emitOpError("expected position attribute rank + source rank to "
1799  "match dest vector rank");
1800  if (!srcVectorType &&
1801  (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
1802  return op.emitOpError(
1803  "expected position attribute rank to match the dest vector rank");
1804  for (const auto &en : llvm::enumerate(positionAttr)) {
1805  auto attr = en.value().dyn_cast<IntegerAttr>();
1806  if (!attr || attr.getInt() < 0 ||
1807  attr.getInt() >= destVectorType.getDimSize(en.index()))
1808  return op.emitOpError("expected position attribute #")
1809  << (en.index() + 1)
1810  << " to be a non-negative integer smaller than the corresponding "
1811  "dest vector dimension";
1812  }
1813  return success();
1814 }
1815 
1816 namespace {
1817 
1818 // If insertOp is only inserting unit dimensions, it can be transformed into
1819 // a broadcast.
1820 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
1821 public:
1822  using OpRewritePattern<InsertOp>::OpRewritePattern;
1823 
1824  LogicalResult matchAndRewrite(InsertOp insertOp,
1825  PatternRewriter &rewriter) const override {
1826  auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
1827  if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
1828  srcVecType.getNumElements())
1829  return failure();
1830  rewriter.replaceOpWithNewOp<BroadcastOp>(
1831  insertOp, insertOp.getDestVectorType(), insertOp.source());
1832  return success();
1833  }
1834 };
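// For illustration (hypothetical values), an insert that merely adds unit
// dimensions, e.g.
//
//   %i = vector.insert %v, %dst[0] : vector<4xf32> into vector<1x4xf32>
//
// is rewritten to %i = vector.broadcast %v : vector<4xf32> to vector<1x4xf32>.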
1835 
1836 } // namespace
1837 
1838 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
1839  MLIRContext *context) {
1840  results.add<InsertToBroadcast, BroadcastFolder>(context);
1841 }
1842 
1843 // Eliminates insert operations that produce values identical to their source
1844 // value. This happens when the source and destination vectors have identical
1845 // sizes.
1846 OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
1847  if (position().empty())
1848  return source();
1849  return {};
1850 }
1851 
1852 //===----------------------------------------------------------------------===//
1853 // InsertMapOp
1854 //===----------------------------------------------------------------------===//
1855 
1856 void InsertMapOp::build(OpBuilder &builder, OperationState &result,
1857  Value vector, Value dest, ValueRange ids) {
1858  InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
1859 }
1860 
1861 static LogicalResult verify(InsertMapOp op) {
1862  if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
1863  return op.emitOpError(
1864  "expected source and destination vectors of same rank");
1865  unsigned numId = 0;
1866  for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) {
1867  if (op.getResultType().getDimSize(i) %
1868  op.getSourceVectorType().getDimSize(i) !=
1869  0)
1870  return op.emitOpError(
1871  "destination vector size must be a multiple of source vector size");
1872  if (op.getResultType().getDimSize(i) !=
1873  op.getSourceVectorType().getDimSize(i))
1874  numId++;
1875  }
1876  if (numId != op.ids().size())
1877  return op.emitOpError("expected number of ids to match the number of "
1878  "dimensions distributed");
1879  return success();
1880 }
1881 
1882 AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }
1883 
1884 //===----------------------------------------------------------------------===//
1885 // InsertStridedSliceOp
1886 //===----------------------------------------------------------------------===//
1887 
1888 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
1889  Value source, Value dest,
1890  ArrayRef<int64_t> offsets,
1891  ArrayRef<int64_t> strides) {
1892  result.addOperands({source, dest});
1893  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
1894  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
1895  result.addTypes(dest.getType());
1896  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
1897  result.addAttribute(getStridesAttrName(), stridesAttr);
1898 }
1899 
1900 // TODO: Should be moved to Tablegen Confined attributes.
1901 template <typename OpType>
1902 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
1903  ArrayAttr arrayAttr,
1904  ArrayRef<int64_t> shape,
1905  StringRef attrName) {
1906  if (arrayAttr.size() > shape.size())
1907  return op.emitOpError("expected ")
1908  << attrName << " attribute of rank smaller than vector rank";
1909  return success();
1910 }
1911 
1912 // Verifies that all integers in `arrayAttr` lie in the admissible interval:
1913 // [min, max) when `halfOpen` is true,
1914 // [min, max] otherwise.
1915 template <typename OpType>
1916 static LogicalResult
1917 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
1918  int64_t max, StringRef attrName,
1919  bool halfOpen = true) {
1920  for (auto attr : arrayAttr) {
1921  auto val = attr.cast<IntegerAttr>().getInt();
1922  auto upper = max;
1923  if (!halfOpen)
1924  upper += 1;
1925  if (val < min || val >= upper)
1926  return op.emitOpError("expected ") << attrName << " to be confined to ["
1927  << min << ", " << upper << ")";
1928  }
1929  return success();
1930 }
1931 
1932 // Verifies that, for each dimension, the integer in `arrayAttr` lies in the
1933 // admissible interval: [min, shape[dim]) when `halfOpen` is true,
1934 // [min, shape[dim]] otherwise.
1935 template <typename OpType>
1936 static LogicalResult
1937 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
1938  ArrayRef<int64_t> shape, StringRef attrName,
1939  bool halfOpen = true, int64_t min = 0) {
1940  assert(arrayAttr.size() <= shape.size());
1941  unsigned index = 0;
1942  for (auto it : llvm::zip(arrayAttr, shape)) {
1943  auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
1944  auto max = std::get<1>(it);
1945  if (!halfOpen)
1946  max += 1;
1947  if (val < min || val >= max)
1948  return op.emitOpError("expected ")
1949  << attrName << " dimension " << index << " to be confined to ["
1950  << min << ", " << max << ")";
1951  ++index;
1952  }
1953  return success();
1954 }
1955 
1956 // Verifies that, for each dimension, the sum of the corresponding integers
1957 // in `arrayAttr1` and `arrayAttr2` lies in the admissible interval:
1958 // [min, shape[dim]) when `halfOpen` is true, [min, shape[dim]] otherwise.
1959 template <typename OpType>
1960 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
1961  OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
1962  ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
1963  bool halfOpen = true, int64_t min = 1) {
1964  assert(arrayAttr1.size() <= shape.size());
1965  assert(arrayAttr2.size() <= shape.size());
1966  unsigned index = 0;
1967  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
1968  auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
1969  auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
1970  auto max = std::get<2>(it);
1971  if (!halfOpen)
1972  max += 1;
1973  if (val1 + val2 < 0 || val1 + val2 >= max)
1974  return op.emitOpError("expected sum(")
1975  << attrName1 << ", " << attrName2 << ") dimension " << index
1976  << " to be confined to [" << min << ", " << max << ")";
1977  ++index;
1978  }
1979  return success();
1980 }
1981 
1982 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
1983  MLIRContext *context) {
1984  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
1985  return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
1986  });
1987  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
1988 }
1989 
1990 static LogicalResult verify(InsertStridedSliceOp op) {
1991  auto sourceVectorType = op.getSourceVectorType();
1992  auto destVectorType = op.getDestVectorType();
1993  auto offsets = op.offsets();
1994  auto strides = op.strides();
1995  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
1996  return op.emitOpError(
1997  "expected offsets of same size as destination vector rank");
1998  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
1999  return op.emitOpError(
2000  "expected strides of same size as source vector rank");
2001  if (sourceVectorType.getRank() > destVectorType.getRank())
2002  return op.emitOpError(
2003  "expected source rank to be smaller than destination rank");
2004 
2005  auto sourceShape = sourceVectorType.getShape();
2006  auto destShape = destVectorType.getShape();
2007  SmallVector<int64_t, 4> sourceShapeAsDestShape(
2008  destShape.size() - sourceShape.size(), 0);
2009  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
2010  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
2011  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
2012  if (failed(
2013  isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
2014  failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
2015  /*halfOpen=*/false)) ||
2016  failed(isSumOfIntegerArrayAttrConfinedToShape(
2017  op, offsets,
2018  makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
2019  offName, "source vector shape",
2020  /*halfOpen=*/false, /*min=*/1)))
2021  return failure();
2022 
2023  return success();
2024 }
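// A well-formed example (illustrative values): offsets rank matches the
// destination rank, strides rank matches the source rank, and each offset
// plus the corresponding source size stays within the destination dim:
//
//   %r = vector.insert_strided_slice %src, %dst
//       {offsets = [2, 0], strides = [1, 1]}
//       : vector<2x4xf32> into vector<8x4xf32>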
2025 
2026 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2027  if (getSourceVectorType() == getDestVectorType())
2028  return source();
2029  return {};
2030 }
2031 
2032 //===----------------------------------------------------------------------===//
2033 // OuterProductOp
2034 //===----------------------------------------------------------------------===//
2035 
2036 /// Build an op without a mask; use the type of `acc` as the return type.
2037 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
2038  Value lhs, Value rhs, Value acc) {
2039  result.addOperands({lhs, rhs, acc});
2040  result.addTypes(acc.getType());
2041 }
2042 
2043 static void print(OpAsmPrinter &p, OuterProductOp op) {
2044  p << " " << op.lhs() << ", " << op.rhs();
2045  if (!op.acc().empty()) {
2046  p << ", " << op.acc();
2047  p.printOptionalAttrDict(op->getAttrs());
2048  }
2049  p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
2050 }
2051 
2052 static ParseResult parseOuterProductOp(OpAsmParser &parser,
2053  OperationState &result) {
2054  SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
2055  Type tLHS, tRHS;
2056  if (parser.parseOperandList(operandsInfo) ||
2057  parser.parseOptionalAttrDict(result.attributes) ||
2058  parser.parseColonType(tLHS) || parser.parseComma() ||
2059  parser.parseType(tRHS))
2060  return failure();
2061  if (operandsInfo.size() < 2)
2062  return parser.emitError(parser.getNameLoc(),
2063  "expected at least 2 operands");
2064  VectorType vLHS = tLHS.dyn_cast<VectorType>();
2065  VectorType vRHS = tRHS.dyn_cast<VectorType>();
2066  if (!vLHS)
2067  return parser.emitError(parser.getNameLoc(),
2068  "expected vector type for operand #1");
2069  VectorType resType =
2070  vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
2071  vLHS.getElementType())
2072  : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
2073 
2074  if (!result.attributes.get(OuterProductOp::getKindAttrName())) {
2075  result.attributes.append(
2076  OuterProductOp::getKindAttrName(),
2077  CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
2078  result.getContext()));
2079  }
2080 
2081  return failure(
2082  parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
2083  parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
2084  (operandsInfo.size() > 2 &&
2085  parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
2086  parser.addTypeToList(resType, result.types));
2087 }
2088 
2089 static LogicalResult verify(OuterProductOp op) {
2090  Type tRHS = op.getOperandTypeRHS();
2091  VectorType vLHS = op.getOperandVectorTypeLHS(),
2092  vRHS = tRHS.dyn_cast<VectorType>(),
2093  vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
2094 
2095  if (vLHS.getRank() != 1)
2096  return op.emitOpError("expected 1-d vector for operand #1");
2097 
2098  if (vRHS) {
2099  // Proper OUTER operation.
2100  if (vRHS.getRank() != 1)
2101  return op.emitOpError("expected 1-d vector for operand #2");
2102  if (vRES.getRank() != 2)
2103  return op.emitOpError("expected 2-d vector result");
2104  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2105  return op.emitOpError("expected #1 operand dim to match result dim #1");
2106  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
2107  return op.emitOpError("expected #2 operand dim to match result dim #2");
2108  } else {
2109  // An AXPY operation.
2110  if (vRES.getRank() != 1)
2111  return op.emitOpError("expected 1-d vector result");
2112  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2113  return op.emitOpError("expected #1 operand dim to match result dim #1");
2114  }
2115 
2116  if (vACC && vACC != vRES)
2117  return op.emitOpError("expected operand #3 of same type as result type");
2118 
2119  // Verify supported combining kind.
2120  if (!isSupportedCombiningKind(op.kind(), vRES.getElementType()))
2121  return op.emitOpError("unsupported outerproduct type");
2122 
2123  return success();
2124 }
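// Two well-formed variants (illustrative values): the proper outer product
//
//   %o = vector.outerproduct %x, %y, %acc : vector<4xf32>, vector<8xf32>
//
// produces vector<4x8xf32>, while the AXPY form with a scalar second operand
//
//   %a = vector.outerproduct %x, %s, %acc : vector<4xf32>, f32
//
// produces vector<4xf32>.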
2125 
2126 //===----------------------------------------------------------------------===//
2127 // ReshapeOp
2128 //===----------------------------------------------------------------------===//
2129 
2130 static LogicalResult verify(ReshapeOp op) {
2131  // Verify that input/output shape rank + num fixed vector sizes matches vector rank.
2132  auto inputVectorType = op.getInputVectorType();
2133  auto outputVectorType = op.getOutputVectorType();
2134  int64_t inputShapeRank = op.getNumInputShapeSizes();
2135  int64_t outputShapeRank = op.getNumOutputShapeSizes();
2136  SmallVector<int64_t, 4> fixedVectorSizes;
2137  op.getFixedVectorSizes(fixedVectorSizes);
2138  int64_t numFixedVectorSizes = fixedVectorSizes.size();
2139 
2140  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
2141  return op.emitError("invalid input shape for vector type ")
2142  << inputVectorType;
2143 
2144  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
2145  return op.emitError("invalid output shape for vector type ")
2146  << outputVectorType;
2147 
2148  // Verify that the 'fixedVectorSizes' match an input/output vector shape
2149  // suffix.
2150  unsigned inputVectorRank = inputVectorType.getRank();
2151  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
2152  unsigned index = inputVectorRank - numFixedVectorSizes - i;
2153  if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
2154  return op.emitError("fixed vector size must match input vector for dim ")
2155  << i;
2156  }
2157 
2158  unsigned outputVectorRank = outputVectorType.getRank();
2159  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
2160  unsigned index = outputVectorRank - numFixedVectorSizes - i;
2161  if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
2162  return op.emitError("fixed vector size must match output vector for dim ")
2163  << i;
2164  }
2165 
2166  // If all shape operands are produced by constant ops, verify that the
2167  // products of the input and output shape sizes match.
2168  auto isDefByConstant = [](Value operand) {
2169  return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
2170  };
2171  if (llvm::all_of(op.input_shape(), isDefByConstant) &&
2172  llvm::all_of(op.output_shape(), isDefByConstant)) {
2173  int64_t numInputElements = 1;
2174  for (auto operand : op.input_shape())
2175  numInputElements *=
2176  cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2177  int64_t numOutputElements = 1;
2178  for (auto operand : op.output_shape())
2179  numOutputElements *=
2180  cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2181  if (numInputElements != numOutputElements)
2182  return op.emitError("product of input and output shape sizes must match");
2183  }
2184  return success();
2185 }
2186 
2187 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
2188  populateFromInt64AttrArray(fixed_vector_sizes(), results);
2189 }
2190 
2191 //===----------------------------------------------------------------------===//
2192 // ExtractStridedSliceOp
2193 //===----------------------------------------------------------------------===//
2194 
2195 // Inference works as follows:
2196 // 1. Add 'sizes' from prefix of dims in 'offsets'.
2197 // 2. Add sizes from 'vectorType' for remaining dims.
2198 static Type inferStridedSliceOpResultType(VectorType vectorType,
2199  ArrayAttr offsets, ArrayAttr sizes,
2200  ArrayAttr strides) {
2201  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
2202  SmallVector<int64_t, 4> shape;
2203  shape.reserve(vectorType.getRank());
2204  unsigned idx = 0;
2205  for (unsigned e = offsets.size(); idx < e; ++idx)
2206  shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
2207  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
2208  shape.push_back(vectorType.getShape()[idx]);
2209 
2210  return VectorType::get(shape, vectorType.getElementType());
2211 }
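// Worked example (illustrative): for vector<4x8x16xf32> with two offsets and
// sizes = [2, 4], the first two result dims come from 'sizes' and the rest
// from the vector type, giving vector<2x4x16xf32>.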
2212 
2213 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2214  Value source, ArrayRef<int64_t> offsets,
2215  ArrayRef<int64_t> sizes,
2216  ArrayRef<int64_t> strides) {
2217  result.addOperands(source);
2218  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2219  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
2220  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2221  result.addTypes(
2222  inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
2223  offsetsAttr, sizesAttr, stridesAttr));
2224  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
2225  result.addAttribute(getSizesAttrName(), sizesAttr);
2226  result.addAttribute(getStridesAttrName(), stridesAttr);
2227 }
2228 
2229 static LogicalResult verify(ExtractStridedSliceOp op) {
2230  auto type = op.getVectorType();
2231  auto offsets = op.offsets();
2232  auto sizes = op.sizes();
2233  auto strides = op.strides();
2234  if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
2235  op.emitOpError(
2236  "expected offsets, sizes and strides attributes of same size");
2237  return failure();
2238  }
2239 
2240  auto shape = type.getShape();
2241  auto offName = ExtractStridedSliceOp::getOffsetsAttrName();
2242  auto sizesName = ExtractStridedSliceOp::getSizesAttrName();
2243  auto stridesName = ExtractStridedSliceOp::getStridesAttrName();
2244  if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
2245  failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
2246  failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
2247  stridesName)) ||
2248  failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
2249  failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
2250  /*halfOpen=*/false,
2251  /*min=*/1)) ||
2252  failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
2253  /*halfOpen=*/false)) ||
2254  failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
2255  offName, sizesName,
2256  /*halfOpen=*/false)))
2257  return failure();
2258 
2259  auto resultType = inferStridedSliceOpResultType(
2260  op.getVectorType(), op.offsets(), op.sizes(), op.strides());
2261  if (op.getResult().getType() != resultType) {
2262  op.emitOpError("expected result type to be ") << resultType;
2263  return failure();
2264  }
2265 
2266  return success();
2267 }
2268 
2269 // When the source of an ExtractStridedSliceOp comes from a chain of
2270 // InsertStridedSliceOps, try to use the source of one of those inserts when
2271 // we can detect that the extracted vector is a subset of a vector it inserted.
2272 static LogicalResult
2273 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
2274  // Helper to extract integer out of ArrayAttr.
2275  auto getElement = [](ArrayAttr array, int idx) {
2276  return array[idx].cast<IntegerAttr>().getInt();
2277  };
2278  ArrayAttr extractOffsets = op.offsets();
2279  ArrayAttr extractStrides = op.strides();
2280  ArrayAttr extractSizes = op.sizes();
2281  auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
2282  while (insertOp) {
2283  if (op.getVectorType().getRank() !=
2284  insertOp.getSourceVectorType().getRank())
2285  return failure();
2286  ArrayAttr insertOffsets = insertOp.offsets();
2287  ArrayAttr insertStrides = insertOp.strides();
2288  // If the rank of the extract offsets is greater than that of the insert
2289  // offsets, we are likely extracting a partial chunk of the vector inserted.
2290  if (extractOffsets.size() > insertOffsets.size())
2291  return failure();
2292  bool partialOverlap = false;
2293  bool disjoint = false;
2294  SmallVector<int64_t, 4> offsetDiffs;
2295  for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
2296  if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
2297  return failure();
2298  int64_t start = getElement(insertOffsets, dim);
2299  int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
2300  int64_t offset = getElement(extractOffsets, dim);
2301  int64_t size = getElement(extractSizes, dim);
2302  // Check if the start of the extract offset is in the interval inserted.
2303  if (start <= offset && offset < end) {
2304  // If the extract interval overlaps but is not fully included we may
2305  // have a partial overlap that will prevent any folding.
2306  if (offset + size > end)
2307  partialOverlap = true;
2308  offsetDiffs.push_back(offset - start);
2309  continue;
2310  }
2311  disjoint = true;
2312  break;
2313  }
2314  // The extracted chunk is a subset of the inserted chunk: the fold applies.
2315  if (!disjoint && !partialOverlap) {
2316  op.setOperand(insertOp.source());
2317  // OpBuilder is only used as a helper to build an I64ArrayAttr.
2318  OpBuilder b(op.getContext());
2319  op->setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
2320  b.getI64ArrayAttr(offsetDiffs));
2321  return success();
2322  }
2323  // If the chunk extracted is disjoint from the chunk inserted, keep looking
2324  // in the insert chain.
2325  if (disjoint)
2326  insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
2327  else {
2328  // The extracted vector partially overlaps the inserted vector; we cannot
2329  // fold.
2330  return failure();
2331  }
2332  }
2333  return failure();
2334 }
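// For illustration (hypothetical values), an extract that exactly covers a
// previously inserted chunk folds away:
//
//   %i = vector.insert_strided_slice %v, %dst
//       {offsets = [2, 0], strides = [1, 1]}
//       : vector<2x4xf32> into vector<8x4xf32>
//   %e = vector.extract_strided_slice %i
//       {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]}
//       : vector<8x4xf32> to vector<2x4xf32>
//
// Here %e reads back exactly %v, so the extract is rebased onto %v with
// offsets [0, 0] and subsequently folds to %v.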
2335 
2336 OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2337  if (getVectorType() == getResult().getType())
2338  return vector();
2339  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
2340  return getResult();
2341  return {};
2342 }
2343 
2344 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
2345  populateFromInt64AttrArray(offsets(), results);
2346 }
2347 
2348 namespace {
2349 
2350 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
2351 // ConstantMaskOp.
2352 class StridedSliceConstantMaskFolder final
2353  : public OpRewritePattern<ExtractStridedSliceOp> {
2354 public:
2355  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2356 
2357  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2358  PatternRewriter &rewriter) const override {
2359  // Return if 'extractStridedSliceOp' operand is not defined by a
2360  // ConstantMaskOp.
2361  auto *defOp = extractStridedSliceOp.vector().getDefiningOp();
2362  auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
2363  if (!constantMaskOp)
2364  return failure();
2365  // Return if 'extractStridedSliceOp' has non-unit strides.
2366  if (extractStridedSliceOp.hasNonUnitStrides())
2367  return failure();
2368  // Gather constant mask dimension sizes.
2369  SmallVector<int64_t, 4> maskDimSizes;
2370  populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
2371  // Gather strided slice offsets and sizes.
2372  SmallVector<int64_t, 4> sliceOffsets;
2373  populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets);
2374  SmallVector<int64_t, 4> sliceSizes;
2375  populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes);
2376 
2377  // Compute slice of vector mask region.
2378  SmallVector<int64_t, 4> sliceMaskDimSizes;
2379  assert(sliceOffsets.size() == maskDimSizes.size());
2380  for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
2381  int64_t maskDimSize = std::get<0>(it);
2382  int64_t sliceOffset = std::get<1>(it);
2383  int64_t sliceSize = std::get<2>(it);
2384  int64_t sliceMaskDimSize = std::max(
2385  static_cast<int64_t>(0),
2386  std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
2387  sliceMaskDimSizes.push_back(sliceMaskDimSize);
2388  }
2389  // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
2390  // region is a conjunction of mask dim intervals).
2391  if (llvm::is_contained(sliceMaskDimSizes, 0))
2392  sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
2393 
2394  // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
2395  // region.
2396  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
2397  extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
2398  vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
2399  return success();
2400  }
2401 };
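// Worked example (illustrative values): slicing a constant mask
//
//   %m = vector.constant_mask [4] : vector<8xi1>
//   %e = vector.extract_strided_slice %m
//       {offsets = [2], sizes = [4], strides = [1]}
//       : vector<8xi1> to vector<4xi1>
//
// keeps min(2 + 4, 4) - 2 = 2 leading set bits, i.e.
// %e = vector.constant_mask [2] : vector<4xi1>.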
2402 
2403 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
2404 class StridedSliceConstantFolder final
2405  : public OpRewritePattern<ExtractStridedSliceOp> {
2406 public:
2407  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2408 
2409  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2410  PatternRewriter &rewriter) const override {
2411  // Return if 'extractStridedSliceOp' operand is not defined by a
2412  // ConstantOp.
2413  auto constantOp =
2414  extractStridedSliceOp.vector().getDefiningOp<arith::ConstantOp>();
2415  if (!constantOp)
2416  return failure();
2417  auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
2418  if (!dense)
2419  return failure();
2420  auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
2421  dense.getSplatValue<Attribute>());
2422  rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
2423  newAttr);
2424  return success();
2425  }
2426 };
2427 
2428 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
2429 // BroadcastOp(ExtractStridedSliceOp).
2430 class StridedSliceBroadcast final
2431  : public OpRewritePattern<ExtractStridedSliceOp> {
2432 public:
2433  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2434 
2435  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
2436  PatternRewriter &rewriter) const override {
2437  auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
2438  if (!broadcast)
2439  return failure();
2440  auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
2441  unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
2442  auto dstVecType = op.getType().cast<VectorType>();
2443  unsigned dstRank = dstVecType.getRank();
2444  unsigned rankDiff = dstRank - srcRank;
2445  // Check if the innermost dimensions of the source of the broadcast are the
2446  // same as the innermost dimensions of the destination of the extract. If
2447  // so, we can just use a broadcast, as the original dimensions are untouched.
2448  bool lowerDimMatch = true;
2449  for (unsigned i = 0; i < srcRank; i++) {
2450  if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
2451  lowerDimMatch = false;
2452  break;
2453  }
2454  }
2455  Value source = broadcast.source();
2456  if (!lowerDimMatch) {
2457  // The inner dimensions don't match; we need to extract from the source of
2458  // the original broadcast and then broadcast the extracted value.
2459  source = rewriter.create<ExtractStridedSliceOp>(
2460  op->getLoc(), source,
2461  getI64SubArray(op.offsets(), /* dropFront=*/rankDiff),
2462  getI64SubArray(op.sizes(), /* dropFront=*/rankDiff),
2463  getI64SubArray(op.strides(), /* dropFront=*/rankDiff));
2464  }
2465  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
2466  return success();
2467  }
2468 };
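// For illustration (hypothetical values): when the inner dims match, the
// slice only trims broadcast dims, so
//
//   %b = vector.broadcast %v : vector<4xf32> to vector<2x4xf32>
//   %e = vector.extract_strided_slice %b
//       {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]}
//       : vector<2x4xf32> to vector<1x4xf32>
//
// is rewritten to %e = vector.broadcast %v : vector<4xf32> to vector<1x4xf32>.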
2469 
2470 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
2471 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
2472 public:
2473  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2474 
2475  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
2476  PatternRewriter &rewriter) const override {
2477  auto splat = op.vector().getDefiningOp<SplatOp>();
2478  if (!splat)
2479  return failure();
2480  rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
2481  return success();
2482  }
2483 };
2484 
2485 } // namespace
2486 
2487 void ExtractStridedSliceOp::getCanonicalizationPatterns(
2488  RewritePatternSet &results, MLIRContext *context) {
2489  // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
2490  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
2491  results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
2492  StridedSliceBroadcast, StridedSliceSplat>(context);
2493 }
2494 
2495 //===----------------------------------------------------------------------===//
2496 // TransferReadOp
2497 //===----------------------------------------------------------------------===//
2498 
2499 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
2500 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2501  VectorType vectorType, Value source,
2502  ValueRange indices, AffineMapAttr permutationMapAttr,
2503  /*optional*/ ArrayAttr inBoundsAttr) {
2504  Type elemType = source.getType().cast<ShapedType>().getElementType();
2505  Value padding = builder.create<arith::ConstantOp>(
2506  result.location, elemType, builder.getZeroAttr(elemType));
2507  build(builder, result, vectorType, source, indices, permutationMapAttr,
2508  padding, /*mask=*/Value(), inBoundsAttr);
2509 }
2510 
2511 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
2512 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2513  VectorType vectorType, Value source,
2514  ValueRange indices, AffineMap permutationMap,
2515  Optional<ArrayRef<bool>> inBounds) {
2516  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
2517  auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
2518  ? builder.getBoolArrayAttr(inBounds.getValue())
2519  : ArrayAttr();
2520  build(builder, result, vectorType, source, indices, permutationMapAttr,
2521  inBoundsAttr);
2522 }
2523 
2524 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
2525 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2526  VectorType vectorType, Value source,
2527  ValueRange indices, Value padding,
2528  Optional<ArrayRef<bool>> inBounds) {
2529  AffineMap permutationMap = getTransferMinorIdentityMap(
2530  source.getType().cast<ShapedType>(), vectorType);
2531  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
2532  auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
2533  ? builder.getBoolArrayAttr(inBounds.getValue())
2534  : ArrayAttr();
2535  build(builder, result, vectorType, source, indices, permutationMapAttr,
2536  padding,
2537  /*mask=*/Value(), inBoundsAttr);
2538 }
2539 
2540 /// 4. Builder that sets padding to zero and permutation map to
2541 /// 'getMinorIdentityMap'.
2542 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2543  VectorType vectorType, Value source,
2544  ValueRange indices,
2545  Optional<ArrayRef<bool>> inBounds) {
2546  Type elemType = source.getType().cast<ShapedType>().getElementType();
2547  Value padding = builder.create<arith::ConstantOp>(
2548  result.location, elemType, builder.getZeroAttr(elemType));
2549  build(builder, result, vectorType, source, indices, padding, inBounds);
2550 }
2551 
2552 template <typename EmitFun>
2553 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
2554  EmitFun emitOpError) {
2555  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
2556  for (auto expr : permutationMap.getResults()) {
2557  auto dim = expr.dyn_cast<AffineDimExpr>();
2558  auto zero = expr.dyn_cast<AffineConstantExpr>();
2559  if (zero) {
2560  if (zero.getValue() != 0) {
2561  return emitOpError(
2562  "requires a projected permutation_map (at most one dim or the zero "
2563  "constant can appear in each result)");
2564  }
2565  continue;
2566  }
2567  if (!dim) {
2568  return emitOpError("requires a projected permutation_map (at most one "
2569  "dim or the zero constant can appear in each result)");
2570  }
2571  if (seen[dim.getPosition()]) {
2572  return emitOpError(
2573  "requires a permutation_map that is a permutation (found one dim "
2574  "used more than once)");
2575  }
2576  seen[dim.getPosition()] = true;
2577  }
2578  return success();
2579 }
2580 
2581 static LogicalResult
2582 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
2583  VectorType vectorType, VectorType maskType,
2584  AffineMap permutationMap, ArrayAttr inBounds) {
2585  if (op->hasAttr("masked")) {
2586  return op->emitOpError("masked attribute has been removed. "
2587  "Use in_bounds instead.");
2588  }
2589 
2590  if (!shapedType.isa<MemRefType, RankedTensorType>())
2591  return op->emitOpError(
2592  "requires source to be a memref or ranked tensor type");
2593 
2594  auto elementType = shapedType.getElementType();
2595  DataLayout dataLayout = DataLayout::closest(op);
2596  if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
2597  // Memref or tensor has vector element type.
2598  unsigned sourceVecSize =
2599  dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
2600  vectorElementType.getShape().back();
2601  unsigned resultVecSize =
2602  dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
2603  vectorType.getShape().back();
2604  if (resultVecSize % sourceVecSize != 0)
2605  return op->emitOpError(
2606  "requires the bitwidth of the minor 1-D vector to be an integral "
2607  "multiple of the bitwidth of the minor 1-D vector of the source");
2608 
2609  unsigned sourceVecEltRank = vectorElementType.getRank();
2610  unsigned resultVecRank = vectorType.getRank();
2611  if (sourceVecEltRank > resultVecRank)
2612  return op->emitOpError(
2613  "requires source vector element and vector result ranks to match.");
2614  unsigned rankOffset = resultVecRank - sourceVecEltRank;
2615  // Check that permutation map results match 'rankOffset' of vector type.
2616  if (permutationMap.getNumResults() != rankOffset)
2617  return op->emitOpError("requires a permutation_map with result dims of "
2618  "the same rank as the vector type");
2619 
2620  if (maskType)
2621  return op->emitOpError("does not support masks with vector element type");
2622  } else {
2623  // Memref or tensor has scalar element type.
2624  unsigned minorSize =
2625  vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
2626  unsigned resultVecSize =
2627  dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
2628  if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
2629  return op->emitOpError(
2630  "requires the bitwidth of the minor 1-D vector to be an integral "
2631  "multiple of the bitwidth of the source element type");
2632 
2633  // Check that permutation map results match rank of vector type.
2634  if (permutationMap.getNumResults() != vectorType.getRank())
2635  return op->emitOpError("requires a permutation_map with result dims of "
2636  "the same rank as the vector type");
2637 
2638  VectorType expectedMaskType =
2639  vector::detail::transferMaskType(vectorType, permutationMap);
2640  if (maskType && expectedMaskType != maskType)
2641  return op->emitOpError("expects mask type consistent with permutation "
2642  "map: ")
2643  << maskType;
2644  }
2645 
2646  if (permutationMap.getNumSymbols() != 0)
2647  return op->emitOpError("requires permutation_map without symbols");
2648 
2649  if (permutationMap.getNumInputs() != shapedType.getRank())
2650  return op->emitOpError("requires a permutation_map with input dims of the "
2651  "same rank as the source type");
2652 
2653  if (inBounds) {
2654  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
2655  return op->emitOpError("expects the optional in_bounds attr of same rank "
2656  "as permutation_map results: ")
2657  << AffineMapAttr::get(permutationMap)
2658  << " vs inBounds of size: " << inBounds.size();
2659  for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
2660  if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
2661  !inBounds.getValue()[i].cast<BoolAttr>().getValue())
2662  return op->emitOpError("requires broadcast dimensions to be in-bounds");
2663  }
2664 
2665  return success();
2666 }
2667 
2668 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
2669  SmallVector<StringRef, 3> elidedAttrs;
2670  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
2671  if (op.permutation_map().isMinorIdentity())
2672  elidedAttrs.push_back(op.getPermutationMapAttrName());
2673  bool elideInBounds = true;
2674  if (auto inBounds = op.in_bounds()) {
2675  for (auto attr : *inBounds) {
2676  if (attr.template cast<BoolAttr>().getValue()) {
2677  elideInBounds = false;
2678  break;
2679  }
2680  }
2681  }
2682  if (elideInBounds)
2683  elidedAttrs.push_back(op.getInBoundsAttrName());
2684  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2685 }
2686 
2687 static void print(OpAsmPrinter &p, TransferReadOp op) {
2688  p << " " << op.source() << "[" << op.indices() << "], " << op.padding();
2689  if (op.mask())
2690  p << ", " << op.mask();
2691  printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
2692  p << " : " << op.getShapedType() << ", " << op.getVectorType();
2693 }
2694 
2695 static ParseResult parseTransferReadOp(OpAsmParser &parser,
2696  OperationState &result) {
2697  auto &builder = parser.getBuilder();
2698  llvm::SMLoc typesLoc;
2699  OpAsmParser::OperandType sourceInfo;
2700  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
2701  OpAsmParser::OperandType paddingInfo;
2702  SmallVector<Type, 2> types;
2703  OpAsmParser::OperandType maskInfo;
2704  // Parsing with support for paddingValue.
2705  if (parser.parseOperand(sourceInfo) ||
2706  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
2707  parser.parseComma() || parser.parseOperand(paddingInfo))
2708  return failure();
2709  ParseResult hasMask = parser.parseOptionalComma();
2710  // Fail on an unparsable mask operand (the comma announced one).
2711  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
2712  return failure();
2713  if (parser.parseOptionalAttrDict(result.attributes) ||
2714  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
2715  return failure();
2716  if (types.size() != 2)
2717  return parser.emitError(typesLoc, "requires two types");
2718  auto indexType = builder.getIndexType();
2719  auto shapedType = types[0].dyn_cast<ShapedType>();
2720  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
2721  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
2722  VectorType vectorType = types[1].dyn_cast<VectorType>();
2723  if (!vectorType)
2724  return parser.emitError(typesLoc, "requires vector type");
2725  auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
2726  Attribute mapAttr = result.attributes.get(permutationAttrName);
2727  if (!mapAttr) {
2728  auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
2729  // Update `mapAttr` that is used later to determine mask type.
2730  mapAttr = AffineMapAttr::get(permMap);
2731  result.attributes.set(permutationAttrName, mapAttr);
2732  }
2733  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
2734  parser.resolveOperands(indexInfo, indexType, result.operands) ||
2735  parser.resolveOperand(paddingInfo, shapedType.getElementType(),
2736  result.operands))
2737  return failure();
2738  if (hasMask.succeeded()) {
2739  if (shapedType.getElementType().dyn_cast<VectorType>())
2740  return parser.emitError(
2741  maskInfo.location, "does not support masks with vector element type");
2742  auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
2743  // Instead of adding the mask type as an op type, compute it based on the
2744  // vector type and the permutation map (to keep the type signature small).
2745  auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
2746  if (parser.resolveOperand(maskInfo, maskType, result.operands))
2747  return failure();
2748  }
2749  result.addAttribute(
2750  TransferReadOp::getOperandSegmentSizeAttr(),
2751  builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1,
2752  static_cast<int32_t>(hasMask.succeeded())}));
2753  return parser.addTypeToList(vectorType, result.types);
2754 }
2755 
2756 static LogicalResult verify(TransferReadOp op) {
2757  // Consistency of elemental types in source and vector.
2758  ShapedType shapedType = op.getShapedType();
2759  VectorType vectorType = op.getVectorType();
2760  VectorType maskType = op.getMaskType();
2761  auto paddingType = op.padding().getType();
2762  auto permutationMap = op.permutation_map();
2763  auto sourceElementType = shapedType.getElementType();
2764 
2765  if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
2766  return op.emitOpError("requires ") << shapedType.getRank() << " indices";
2767 
2768  if (failed(
2769  verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
2770  shapedType, vectorType, maskType, permutationMap,
2771  op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
2772  return failure();
2773 
2774  if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
2775  // Source has vector element type.
2776  // Check that 'sourceVectorElementType' and 'paddingType' types match.
2777  if (sourceVectorElementType != paddingType)
2778  return op.emitOpError(
2779  "requires source element type and padding type to match.");
2780 
2781  } else {
2782  // Check that 'paddingType' is valid to store in a vector type.
2783  if (!VectorType::isValidElementType(paddingType))
2784  return op.emitOpError("requires valid padding vector elemental type");
2785 
2786  // Check that padding type and vector element types match.
2787  if (paddingType != sourceElementType)
2788  return op.emitOpError(
2789  "requires formal padding and source of the same elemental type");
2790  }
2791 
2792  return verifyPermutationMap(permutationMap,
2793  [&op](Twine t) { return op.emitOpError(t); });
2794 }
2795 
2796 /// This is a common helper used for patterns of the form
2797 /// ```
2798 /// someop(memrefcast) -> someop
2799 /// ```
2800 /// It folds the source of the memref.cast into the root operation directly.
2801 static LogicalResult foldMemRefCast(Operation *op) {
2802  bool folded = false;
2803  for (OpOperand &operand : op->getOpOperands()) {
2804  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
2805  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
2806  operand.set(castOp.getOperand());
2807  folded = true;
2808  }
2809  }
2810  return success(folded);
2811 }
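
// Sketch (assumed example, not from the source) of what foldMemRefCast does:
//
// ```
// %0 = memref.cast %A : memref<4x8xf32> to memref<?x?xf32>
// %v = vector.transfer_read %0[%i, %j], %pad : memref<?x?xf32>, vector<4xf32>
// ```
//
// becomes a transfer_read directly on %A, since the cast only erases static
// shape information and can fold into its consumer.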
2812 
2813 static LogicalResult foldTensorCast(Operation *op) {
2814  bool folded = false;
2815  for (OpOperand &operand : op->getOpOperands()) {
2816  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
2817  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
2818  operand.set(castOp.getOperand());
2819  folded = true;
2820  }
2821  }
2822  return success(folded);
2823 }
2824 
2825 template <typename TransferOp>
2826 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
2827  // TODO: support more aggressive createOrFold on:
2828  // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
2829  if (op.getShapedType().isDynamicDim(indicesIdx))
2830  return false;
2831  Value index = op.indices()[indicesIdx];
2832  auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>();
2833  if (!cstOp)
2834  return false;
2835 
2836  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
2837  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
2838 
2839  return cstOp.value() + vectorSize <= sourceSize;
2840 }
2841 
2842 template <typename TransferOp>
2843 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
2844  // TODO: support 0-d corner case.
2845  // TODO: Be less conservative.
2846  if (op.getTransferRank() == 0)
2847  return failure();
2848  AffineMap permutationMap = op.permutation_map();
2849  bool changed = false;
2850  SmallVector<bool, 4> newInBounds;
2851  newInBounds.reserve(op.getTransferRank());
2852  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
2853  // Already marked as in-bounds, nothing to see here.
2854  if (op.isDimInBounds(i)) {
2855  newInBounds.push_back(true);
2856  continue;
2857  }
2858  // Currently out-of-bounds, check whether we can statically determine it is
2859  // inBounds.
2860  auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
2861  assert(dimExpr && "Broadcast dims must be in-bounds");
2862  auto inBounds =
2863  isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
2864  newInBounds.push_back(inBounds);
2865  // We commit the pattern if it is "more inbounds".
2866  changed |= inBounds;
2867  }
2868  if (!changed)
2869  return failure();
2870  // OpBuilder is only used as a helper to build an I64ArrayAttr.
2871  OpBuilder b(op.getContext());
2872  op->setAttr(TransferOp::getInBoundsAttrName(),
2873  b.getBoolArrayAttr(newInBounds));
2874  return success();
2875 }
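
// A worked example of the in-bounds inference above (illustrative): with a
// constant index %c2 into a static dimension of size 8 and a vector of size
// 4, the check 2 + 4 <= 8 succeeds, so the attribute is tightened:
//
// ```
// %v = vector.transfer_read %A[%c2], %pad : memref<8xf32>, vector<4xf32>
// // folds to carry {in_bounds = [true]}.
// ```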
2876 
2877 /// ```
2878 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
2879 /// : vector<1x4xf32>, tensor<4x4xf32>
2880 /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
2881 /// : tensor<4x4xf32>, vector<1x4xf32>
2882 /// ```
2883 /// -> Folds into
2884 /// ```
2885 /// %v0
2886 /// ```
2887 static Value foldRAW(TransferReadOp readOp) {
2888  if (!readOp.getShapedType().isa<RankedTensorType>())
2889  return {};
2890  auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>();
2891  while (defWrite) {
2892  if (checkSameValueRAW(defWrite, readOp))
2893  return defWrite.vector();
2894  if (!isDisjointTransferIndices(
2895  cast<VectorTransferOpInterface>(defWrite.getOperation()),
2896  cast<VectorTransferOpInterface>(readOp.getOperation())))
2897  break;
2898  defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
2899  }
2900  return {};
2901 }
2902 
2903 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
2904  if (Value vec = foldRAW(*this))
2905  return vec;
2906  /// transfer_read(memrefcast) -> transfer_read
2907  if (succeeded(foldTransferInBoundsAttribute(*this)))
2908  return getResult();
2909  if (succeeded(foldMemRefCast(*this)))
2910  return getResult();
2911  if (succeeded(foldTensorCast(*this)))
2912  return getResult();
2913  return OpFoldResult();
2914 }
2915 
2916 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
2917  return llvm::to_vector<4>(getVectorType().getShape());
2918 }
2919 
2920 void TransferReadOp::getEffects(
2921  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2922  &effects) {
2923  if (getShapedType().isa<MemRefType>())
2924  effects.emplace_back(MemoryEffects::Read::get(), source(),
2925  SideEffects::DefaultResource::get());
2926 }
2927 
2928 namespace {
2929 /// Fold transfer_reads of a tensor.extract_slice op. E.g.:
2930 ///
2931 /// ```
2932 /// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
2933 /// : tensor<?x?xf32> to tensor<?x?xf32>
2934 /// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
2935 /// : tensor<?x?xf32>, vector<4x5xf32>
2936 /// ```
2937 /// is rewritten to:
2938 /// ```
2939 /// %p0 = arith.addi %a, %e : index
2940 /// %p1 = arith.addi %b, %f : index
2941 /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
2942 /// : tensor<?x?xf32>, vector<4x5xf32>
2943 /// ```
2944 struct FoldExtractSliceIntoTransferRead
2945  : public OpRewritePattern<TransferReadOp> {
2946 public:
2947  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
2948 
2949  LogicalResult matchAndRewrite(TransferReadOp xferOp,
2950  PatternRewriter &rewriter) const override {
2951  // TODO: support 0-d corner case.
2952  if (xferOp.getTransferRank() == 0)
2953  return failure();
2954  if (xferOp.hasOutOfBoundsDim())
2955  return failure();
2956  if (!xferOp.permutation_map().isIdentity())
2957  return failure();
2958  if (xferOp.mask())
2959  return failure();
2960  auto extractOp = xferOp.source().getDefiningOp<tensor::ExtractSliceOp>();
2961  if (!extractOp)
2962  return failure();
2963  if (!extractOp.hasUnitStride())
2964  return failure();
2965 
2966  // Bail on illegal rank-reduction: we need to check that the rank-reduced
2967  // dims are exactly the leading dims. I.e. the following is illegal:
2968  // ```
2969  // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
2970  // tensor<2x1x4xf32> to tensor<2x4xf32>
2971  // %1 = vector.transfer_read %0[0,0], %cst :
2972  // tensor<2x4xf32>, vector<2x4xf32>
2973  // ```
2974  //
2975  // Cannot fold into:
2976  // ```
2977  // %0 = vector.transfer_read %t[0,0,0], %cst :
2978  // tensor<2x1x4xf32>, vector<2x4xf32>
2979  // ```
2980  // For this, check the trailing `vectorRank` dims of the extract_slice
2981  // result tensor match the trailing dims of the inferred result tensor.
2982  int64_t rankReduced =
2983  extractOp.getSourceType().getRank() - extractOp.getType().getRank();
2984  int64_t vectorRank = xferOp.getVectorType().getRank();
2985  RankedTensorType inferredDestTensorType =
2986  tensor::ExtractSliceOp::inferResultType(
2987  extractOp.getSourceType(), extractOp.getMixedOffsets(),
2988  extractOp.getMixedSizes(), extractOp.getMixedStrides());
2989  auto actualDestTensorShape = extractOp.getType().getShape();
2990  if (rankReduced > 0 &&
2991  actualDestTensorShape.take_back(vectorRank) !=
2992  inferredDestTensorType.getShape().take_back(vectorRank))
2993  return failure();
2994 
2995  SmallVector<Value> newIndices;
2996  // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
2997  // indices first.
2998  for (int64_t i = 0; i < rankReduced; ++i) {
2999  OpFoldResult offset = extractOp.getMixedOffsets()[i];
3000  newIndices.push_back(getValueOrCreateConstantIndexOp(
3001  rewriter, extractOp.getLoc(), offset));
3002  }
3003  for (const auto &it : llvm::enumerate(xferOp.indices())) {
3004  OpFoldResult offset =
3005  extractOp.getMixedOffsets()[it.index() + rankReduced];
3006  newIndices.push_back(rewriter.create<arith::AddIOp>(
3007  xferOp->getLoc(), it.value(),
3008  getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
3009  offset)));
3010  }
3011  SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
3012  rewriter.replaceOpWithNewOp<TransferReadOp>(
3013  xferOp, xferOp.getVectorType(), extractOp.source(), newIndices,
3014  xferOp.padding(), ArrayRef<bool>{inBounds});
3015 
3016  return success();
3017  }
3018 };
3019 } // namespace
3020 
3021 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3022  MLIRContext *context) {
3023  results.add<FoldExtractSliceIntoTransferRead>(context);
3024 }
3025 
3026 //===----------------------------------------------------------------------===//
3027 // TransferWriteOp
3028 //===----------------------------------------------------------------------===//
3029 
3030 /// 1. Builder with type inference.
3031 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3032  Value vector, Value dest, ValueRange indices,
3033  AffineMapAttr permutationMapAttr,
3034  /*optional*/ Value mask,
3035  /*optional*/ ArrayAttr inBoundsAttr) {
3036  Type resultType = dest.getType().dyn_cast<RankedTensorType>();
3037  build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
3038  mask, inBoundsAttr);
3039 }
3040 
3041 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
3042 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3043  Value vector, Value dest, ValueRange indices,
3044  AffineMapAttr permutationMapAttr,
3045  /*optional*/ ArrayAttr inBoundsAttr) {
3046  build(builder, result, vector, dest, indices, permutationMapAttr,
3047  /*mask=*/Value(), inBoundsAttr);
3048 }
3049 
3050 /// 3. Builder with type inference that sets an empty mask (variant without
3051 /// attrs)
3052 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3053  Value vector, Value dest, ValueRange indices,
3054  AffineMap permutationMap,
3055  Optional<ArrayRef<bool>> inBounds) {
3056  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3057  auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
3058  ? builder.getBoolArrayAttr(inBounds.getValue())
3059  : ArrayAttr();
3060  build(builder, result, vector, dest, indices, permutationMapAttr,
3061  /*mask=*/Value(), inBoundsAttr);
3062 }
3063 
3064 /// 4. Builder with type inference that sets an empty mask and sets permutation
3065 /// map to 'getMinorIdentityMap'.
3066 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3067  Value vector, Value dest, ValueRange indices,
3068  Optional<ArrayRef<bool>> inBounds) {
3069  auto vectorType = vector.getType().cast<VectorType>();
3070  AffineMap permutationMap = getTransferMinorIdentityMap(
3071  dest.getType().cast<ShapedType>(), vectorType);
3072  build(builder, result, vector, dest, indices, permutationMap, inBounds);
3073 }
3074 
3075 static ParseResult parseTransferWriteOp(OpAsmParser &parser,
3076  OperationState &result) {
3077  auto &builder = parser.getBuilder();
3078  llvm::SMLoc typesLoc;
3079  OpAsmParser::OperandType vectorInfo, sourceInfo;
3080  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
3081  SmallVector<Type, 2> types;
3082  OpAsmParser::OperandType maskInfo;
3083  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
3084  parser.parseOperand(sourceInfo) ||
3085  parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
3086  return failure();
3087  ParseResult hasMask = parser.parseOptionalComma();
3088  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
3089  return failure();
3090  if (parser.parseOptionalAttrDict(result.attributes) ||
3091  parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
3092  return failure();
3093  if (types.size() != 2)
3094  return parser.emitError(typesLoc, "requires two types");
3095  auto indexType = builder.getIndexType();
3096  VectorType vectorType = types[0].dyn_cast<VectorType>();
3097  if (!vectorType)
3098  return parser.emitError(typesLoc, "requires vector type");
3099  ShapedType shapedType = types[1].dyn_cast<ShapedType>();
3100  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
3101  return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3102  auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
3103  auto attr = result.attributes.get(permutationAttrName);
3104  if (!attr) {
3105  auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3106  result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
3107  }
3108  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
3109  parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
3110  parser.resolveOperands(indexInfo, indexType, result.operands))
3111  return failure();
3112  if (hasMask.succeeded()) {
3113  if (shapedType.getElementType().dyn_cast<VectorType>())
3114  return parser.emitError(
3115  maskInfo.location, "does not support masks with vector element type");
3116  auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
3117  if (parser.resolveOperand(maskInfo, maskType, result.operands))
3118  return failure();
3119  }
3120  result.addAttribute(
3121  TransferWriteOp::getOperandSegmentSizeAttr(),
3122  builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()),
3123  static_cast<int32_t>(hasMask.succeeded())}));
3124  return failure(shapedType.isa<RankedTensorType>() &&
3125  parser.addTypeToList(shapedType, result.types));
3126 }
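
// For reference (assumed forms, matching the parser above): both the
// unmasked and the masked spelling are accepted, with the optional mask
// operand following the indices after a comma:
//
// ```
// vector.transfer_write %v, %A[%i, %j] : vector<4x8xf32>, memref<?x?xf32>
// vector.transfer_write %v, %A[%i, %j], %m : vector<4x8xf32>, memref<?x?xf32>
// ```
// where %m has type vector<4x8xi1>, as computed above.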
3127 
3128 static void print(OpAsmPrinter &p, TransferWriteOp op) {
3129  p << " " << op.vector() << ", " << op.source() << "[" << op.indices() << "]";
3130  if (op.mask())
3131  p << ", " << op.mask();
3132  printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
3133  p << " : " << op.getVectorType() << ", " << op.getShapedType();
3134 }
3135 
3136 static LogicalResult verify(TransferWriteOp op) {
3137  // Consistency of elemental types in shape and vector.
3138  ShapedType shapedType = op.getShapedType();
3139  VectorType vectorType = op.getVectorType();
3140  VectorType maskType = op.getMaskType();
3141  auto permutationMap = op.permutation_map();
3142 
3143  if (llvm::size(op.indices()) != shapedType.getRank())
3144  return op.emitOpError("requires ") << shapedType.getRank() << " indices";
3145 
3146  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
3147  // as the semantics is unclear. This can be revisited later if necessary.
3148  if (op.hasBroadcastDim())
3149  return op.emitOpError("should not have broadcast dimensions");
3150 
3151  if (failed(
3152  verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
3153  shapedType, vectorType, maskType, permutationMap,
3154  op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
3155  return failure();
3156 
3157  return verifyPermutationMap(permutationMap,
3158  [&op](Twine t) { return op.emitOpError(t); });
3159 }
3160 
3161 /// Fold:
3162 /// ```
3163 /// %t1 = ...
3164 /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
3165 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
3166 /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
3167 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
3168 /// ```
3169 ///
3170 /// into:
3171 ///
3172 /// ```
3173 /// %t0
3174 /// ```
3175 ///
3176 /// The producer of t1 may or may not be DCE'd depending on whether it is a
3177 /// block argument or has side effects.
3178 static LogicalResult foldReadInitWrite(TransferWriteOp write,
3179  ArrayRef<Attribute>,
3180  SmallVectorImpl<OpFoldResult> &results) {
3181  // TODO: support 0-d corner case.
3182  if (write.getTransferRank() == 0)
3183  return failure();
3184  auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>();
3185  // If not operating on tensors, bail.
3186  if (!rankedTensorType)
3187  return failure();
3188  // If no read, bail.
3189  auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
3190  if (!read)
3191  return failure();
3192  // TODO: support 0-d corner case.
3193  if (read.getTransferRank() == 0)
3194  return failure();
3195  // For now, only accept minor identity maps; compositions that form a minor identity could be accepted later.
3196  if (!read.permutation_map().isMinorIdentity() ||
3197  !write.permutation_map().isMinorIdentity())
3198  return failure();
3199  // Bail on mismatching ranks.
3200  if (read.getTransferRank() != write.getTransferRank())
3201  return failure();
3202  // Bail on potential out-of-bounds accesses.
3203  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
3204  return failure();
3205  // Tensor types must be the same.
3206  if (read.source().getType() != rankedTensorType)
3207  return failure();
3208  // Vector types must be the same.
3209  if (read.getVectorType() != write.getVectorType())
3210  return failure();
3211  // Vector and Tensor shapes must match.
3212  if (read.getVectorType().getShape() != rankedTensorType.getShape())
3213  return failure();
3214  // Bail if any read or write index is nonzero.
3215  auto isNotConstantZero = [](Value v) {
3216  auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
3217  return !cstOp || cstOp.value() != 0;
3218  };
3219  if (llvm::any_of(read.indices(), isNotConstantZero) ||
3220  llvm::any_of(write.indices(), isNotConstantZero))
3221  return failure();
3222  // Success.
3223  results.push_back(read.source());
3224  return success();
3225 }
3226 
3227 static bool checkSameValueWAR(vector::TransferReadOp read,
3228  vector::TransferWriteOp write) {
3229  return read.source() == write.source() && read.indices() == write.indices() &&
3230  read.permutation_map() == write.permutation_map() &&
3231  read.getVectorType() == write.getVectorType() && !read.mask() &&
3232  !write.mask();
3233 }
3234 /// Fold transfer_write write after read:
3235 /// ```
3236 /// %t0 = ...
3237 /// %v = vector.transfer_read %t0[%c0...] :
3238 /// tensor<static_sizesxf32>, vector<static_sizesxf32>
3239 /// %t1 = vector.transfer_write %v, %t0[%c0...] :
3240 /// vector<static_sizesxf32>, tensor<static_sizesxf32>
3241 /// ```
3242 ///
3243 /// into:
3244 ///
3245 /// ```
3246 /// %t0
3247 /// ```
3248 static LogicalResult foldWAR(TransferWriteOp write,
3249  SmallVectorImpl<OpFoldResult> &results) {
3250  if (!write.source().getType().isa<RankedTensorType>())
3251  return failure();
3252  auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
3253  if (!read)
3254  return failure();
3255 
3256  if (!checkSameValueWAR(read, write))
3257  return failure();
3258  results.push_back(read.source());
3259  return success();
3260 }
3261 
3262 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
3263  SmallVectorImpl<OpFoldResult> &results) {
3264  if (succeeded(foldReadInitWrite(*this, operands, results)))
3265  return success();
3266  if (succeeded(foldWAR(*this, results)))
3267  return success();
3268  if (succeeded(foldTransferInBoundsAttribute(*this)))
3269  return success();
3270  return foldMemRefCast(*this);
3271 }
3272 
3273 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
3274  return llvm::to_vector<4>(getVectorType().getShape());
3275 }
3276 
3277 void TransferWriteOp::getEffects(
3278  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3279  &effects) {
3280  if (getShapedType().isa<MemRefType>())
3281  effects.emplace_back(MemoryEffects::Write::get(), source(),
3282  SideEffects::DefaultResource::get());
3283 }
3284 
3285 namespace {
3286 /// Remove a dead transfer write from the SSA chain so that it can be
3287 /// eliminated by DCE, e.g.:
3288 /// ```
3289 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3290 /// : vector<1x4xf32>, tensor<4x4xf32>
3291 /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
3292 /// : vector<1x4xf32>, tensor<4x4xf32>
3293 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3294 /// : vector<1x4xf32>, tensor<4x4xf32>
3295 /// ```
3296 ///
3297 /// into:
3298 ///
3299 /// ```
3300 /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3301 /// : vector<1x4xf32>, tensor<4x4xf32>
3302 /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
3303 /// : vector<1x4xf32>, tensor<4x4xf32>
3304 /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3305 /// : vector<1x4xf32>, tensor<4x4xf32>
3306 /// ```
3307 ///
3308 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
3309 /// any other uses.
3310 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
3311 public:
3312  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
3313  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
3314  PatternRewriter &rewriter) const override {
3315  if (!writeOp.getShapedType().isa<RankedTensorType>())
3316  return failure();
3317  vector::TransferWriteOp writeToModify = writeOp;
3318 
3319  auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
3320  while (defWrite) {
3321  if (checkSameValueWAW(writeOp, defWrite)) {
3322  writeToModify.sourceMutable().assign(defWrite.source());
3323  return success();
3324  }
3325  if (!isDisjointTransferIndices(
3326  cast<VectorTransferOpInterface>(defWrite.getOperation()),
3327  cast<VectorTransferOpInterface>(writeOp.getOperation())))
3328  break;
3329  // If the previous write op doesn't have any other use, we can safely look
3330  // at the previous store to see if it can be removed.
3331  if (!defWrite->hasOneUse())
3332  break;
3333  writeToModify = defWrite;
3334  defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
3335  }
3336  return failure();
3337  }
3338 };
3339 
3340 /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
3341 /// could directly write to the insert_slice's destination. E.g.:
3342 ///
3343 /// ```
3344 /// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
3345 /// : vector<4x5xf32>, tensor<4x5xf32>
3346 /// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
3347 /// : tensor<4x5xf32> into tensor<?x?xf32>
3348 /// ```
3349 /// is rewritten to:
3350 /// ```
3351 /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
3352 /// : vector<4x5xf32>, tensor<?x?xf32>
3353 /// ```
3354 struct FoldInsertSliceIntoTransferWrite
3355  : public OpRewritePattern<tensor::InsertSliceOp> {
3356 public:
3357  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
3358 
3359  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
3360  PatternRewriter &rewriter) const override {
3361  if (!insertOp.hasUnitStride())
3362  return failure();
3363 
3364  auto xferOp = insertOp.source().getDefiningOp<TransferWriteOp>();
3365  if (!xferOp)
3366  return failure();
3367  // TODO: support 0-d corner case.
3368  if (xferOp.getTransferRank() == 0)
3369  return failure();
3370 
3371  if (xferOp.hasOutOfBoundsDim())
3372  return failure();
3373  if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
3374  return failure();
3375  if (xferOp.mask())
3376  return failure();
3377  // Fold only if the TransferWriteOp completely overwrites the `source` with
3378  // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
3379  // content is the data of the vector.
3380  if (!llvm::equal(xferOp.getVectorType().getShape(),
3381  xferOp.getShapedType().getShape()))
3382  return failure();
3383  if (!xferOp.permutation_map().isIdentity())
3384  return failure();
3385 
3386  // Bail on illegal rank-reduction: we need to check that the rank-reduced
3387  // dims are exactly the leading dims. I.e. the following is illegal:
3388  // ```
3389  // %0 = vector.transfer_write %v, %t[0,0], %cst :
3390  // vector<2x4xf32>, tensor<2x4xf32>
3391  // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
3392  // tensor<2x4xf32> into tensor<2x1x4xf32>
3393  // ```
3394  //
3395  // Cannot fold into:
3396  // ```
3397  // %0 = vector.transfer_write %v, %t[0,0,0], %cst :
3398  // vector<2x4xf32>, tensor<2x1x4xf32>
3399  // ```
3400  // For this, check the trailing `vectorRank` dims of the insert_slice result
3401  // tensor match the trailing dims of the inferred result tensor.
3402  int64_t rankReduced =
3403  insertOp.getType().getRank() - insertOp.getSourceType().getRank();
3404  int64_t vectorRank = xferOp.getVectorType().getRank();
3405  RankedTensorType inferredSourceTensorType =
3406  tensor::ExtractSliceOp::inferResultType(
3407  insertOp.getType(), insertOp.getMixedOffsets(),
3408  insertOp.getMixedSizes(), insertOp.getMixedStrides());
3409  auto actualSourceTensorShape = insertOp.getSourceType().getShape();
3410  if (rankReduced > 0 &&
3411  actualSourceTensorShape.take_back(vectorRank) !=
3412  inferredSourceTensorType.getShape().take_back(vectorRank))
3413  return failure();
3414 
3415  SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
3416  rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
3417  SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
3418  rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.vector(),
3419  insertOp.dest(), indices,
3420  ArrayRef<bool>{inBounds});
3421  return success();
3422  }
3423 };
3424 } // namespace
3425 
3426 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
3427  MLIRContext *context) {
3428  results.add<FoldWaw, FoldInsertSliceIntoTransferWrite>(context);
3429 }
3430 
3431 //===----------------------------------------------------------------------===//
3432 // LoadOp
3433 //===----------------------------------------------------------------------===//
3434 
3435 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
3436  MemRefType memRefTy) {
3437  if (!isLastMemrefDimUnitStride(memRefTy))
3438  return op->emitOpError("most minor memref dim must have unit stride");
3439  return success();
3440 }
3441 
3442 static LogicalResult verify(vector::LoadOp op) {
3443  VectorType resVecTy = op.getVectorType();
3444  MemRefType memRefTy = op.getMemRefType();
3445 
3446  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
3447  return failure();
3448 
3449  // Checks for vector memrefs.
3450  Type memElemTy = memRefTy.getElementType();
3451  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
3452  if (memVecTy != resVecTy)
3453  return op.emitOpError("base memref and result vector types should match");
3454  memElemTy = memVecTy.getElementType();
3455  }
3456 
3457  if (resVecTy.getElementType() != memElemTy)
3458  return op.emitOpError("base and result element types should match");
3459  if (llvm::size(op.indices()) != memRefTy.getRank())
3460  return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
3461  return success();
3462 }
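
// Two loads this verifier admits (illustrative examples): a scalar-element
// memref, and a memref whose element type is itself a vector that must match
// the result vector exactly:
//
// ```
// %0 = vector.load %A[%i] : memref<100xf32>, vector<8xf32>
// %1 = vector.load %B[%i] : memref<4xvector<8xf32>>, vector<8xf32>
// ```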
3463 
3464 OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
3465  if (succeeded(foldMemRefCast(*this)))
3466  return getResult();
3467  return OpFoldResult();
3468 }
3469 
3470 //===----------------------------------------------------------------------===//
3471 // StoreOp
3472 //===----------------------------------------------------------------------===//
3473 
3474 static LogicalResult verify(vector::StoreOp op) {
3475  VectorType valueVecTy = op.getVectorType();
3476  MemRefType memRefTy = op.getMemRefType();
3477 
3478  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
3479  return failure();
3480 
3481  // Checks for vector memrefs.
3482  Type memElemTy = memRefTy.getElementType();
3483  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
3484  if (memVecTy != valueVecTy)
3485  return op.emitOpError(
3486  "base memref and valueToStore vector types should match");
3487  memElemTy = memVecTy.getElementType();
3488  }
3489 
3490  if (valueVecTy.getElementType() != memElemTy)
3491  return op.emitOpError("base and valueToStore element type should match");
3492  if (llvm::size(op.indices()) != memRefTy.getRank())
3493  return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
3494  return success();
3495 }
3496 
3497 LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
3498  SmallVectorImpl<OpFoldResult> &results) {
3499  return foldMemRefCast(*this);
3500 }
3501 
3502 //===----------------------------------------------------------------------===//
3503 // MaskedLoadOp
3504 //===----------------------------------------------------------------------===//
3505 
3506 static LogicalResult verify(MaskedLoadOp op) {
3507  VectorType maskVType = op.getMaskVectorType();
3508  VectorType passVType = op.getPassThruVectorType();
3509  VectorType resVType = op.getVectorType();
3510  MemRefType memType = op.getMemRefType();
3511 
3512  if (resVType.getElementType() != memType.getElementType())
3513  return op.emitOpError("base and result element type should match");
3514  if (llvm::size(op.indices()) != memType.getRank())
3515  return op.emitOpError("requires ") << memType.getRank() << " indices";
3516  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3517  return op.emitOpError("expected result dim to match mask dim");
3518  if (resVType != passVType)
3519  return op.emitOpError("expected pass_thru of same type as result type");
3520  return success();
3521 }
3522 
3523 namespace {
3524 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
3525 public:
3526  using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
3527  LogicalResult matchAndRewrite(MaskedLoadOp load,
3528  PatternRewriter &rewriter) const override {
3529  switch (get1DMaskFormat(load.mask())) {
3530  case MaskFormat::AllTrue:
3531  rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(),
3532  load.base(), load.indices());
3533  return success();
3534  case MaskFormat::AllFalse:
3535  rewriter.replaceOp(load, load.pass_thru());
3536  return success();
3537  case MaskFormat::Unknown:
3538  return failure();
3539  }
3540  llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
3541  }
3542 };
3543 } // namespace
3544 
3545 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3546  MLIRContext *context) {
3547  results.add<MaskedLoadFolder>(context);
3548 }
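
// Sketch of the MaskedLoadFolder rewrite (assumed example): an all-true
// constant mask degrades the masked load to a plain load, while an all-false
// mask replaces it with the pass-through value:
//
// ```
// %m = vector.constant_mask [8] : vector<8xi1>
// %0 = vector.maskedload %A[%i], %m, %pass
//     : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
// // canonicalizes to:
// %0 = vector.load %A[%i] : memref<?xf32>, vector<8xf32>
// ```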
3549 
3550 OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
3551  if (succeeded(foldMemRefCast(*this)))
3552  return getResult();
3553  return OpFoldResult();
3554 }
3555 
3556 //===----------------------------------------------------------------------===//
3557 // MaskedStoreOp
3558 //===----------------------------------------------------------------------===//
3559 
3560 static LogicalResult verify(MaskedStoreOp op) {
3561  VectorType maskVType = op.getMaskVectorType();
3562  VectorType valueVType = op.getVectorType();
3563  MemRefType memType = op.getMemRefType();
3564 
3565  if (valueVType.getElementType() != memType.getElementType())
3566  return op.emitOpError("base and valueToStore element type should match");
3567  if (llvm::size(op.indices()) != memType.getRank())
3568  return op.emitOpError("requires ") << memType.getRank() << " indices";
3569  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
3570  return op.emitOpError("expected valueToStore dim to match mask dim");
3571  return success();
3572 }
3573 
3574 namespace {
3575 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
3576 public:
3577  using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
3578  LogicalResult matchAndRewrite(MaskedStoreOp store,
3579  PatternRewriter &rewriter) const override {
3580  switch (get1DMaskFormat(store.mask())) {
3581  case MaskFormat::AllTrue:
3582  rewriter.replaceOpWithNewOp<vector::StoreOp>(
3583  store, store.valueToStore(), store.base(), store.indices());
3584  return success();
3585  case MaskFormat::AllFalse:
3586  rewriter.eraseOp(store);
3587  return success();
3588  case MaskFormat::Unknown:
3589  return failure();
3590  }
3591  llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
3592  }
3593 };
3594 } // namespace
3595 
3596 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3597  MLIRContext *context) {
3598  results.add<MaskedStoreFolder>(context);
3599 }
3600 
3601 LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
3602  SmallVectorImpl<OpFoldResult> &results) {
3603  return foldMemRefCast(*this);
3604 }
3605 
3606 //===----------------------------------------------------------------------===//
3607 // GatherOp
3608 //===----------------------------------------------------------------------===//
3609 
3610 static LogicalResult verify(GatherOp op) {
3611  VectorType indVType = op.getIndexVectorType();
3612  VectorType maskVType = op.getMaskVectorType();
3613  VectorType resVType = op.getVectorType();
3614  MemRefType memType = op.getMemRefType();
3615 
3616  if (resVType.getElementType() != memType.getElementType())
3617  return op.emitOpError("base and result element type should match");
3618  if (llvm::size(op.indices()) != memType.getRank())
3619  return op.emitOpError("requires ") << memType.getRank() << " indices";
3620  if (resVType.getDimSize(0) != indVType.getDimSize(0))
3621  return op.emitOpError("expected result dim to match indices dim");
3622  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3623  return op.emitOpError("expected result dim to match mask dim");
3624  if (resVType != op.getPassThruVectorType())
3625  return op.emitOpError("expected pass_thru of same type as result type");
3626  return success();
3627 }
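
// An illustrative gather accepted by these checks (assumed example): the
// index, mask, and pass-through vectors all share the leading dimension of
// the result vector:
//
// ```
// %g = vector.gather %A[%c0][%idxs], %mask, %pass
//     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
//       into vector<16xf32>
// ```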
3628 
3629 namespace {
3630 class GatherFolder final : public OpRewritePattern<GatherOp> {
3631 public:
3632  using OpRewritePattern<GatherOp>::OpRewritePattern;
3633  LogicalResult matchAndRewrite(GatherOp gather,
3634  PatternRewriter &rewriter) const override {
3635  switch (get1DMaskFormat(gather.mask())) {
3636  case MaskFormat::AllTrue:
3637  return failure(); // no unmasked equivalent
3638  case MaskFormat::AllFalse:
3639  rewriter.replaceOp(gather, gather.pass_thru());
3640  return success();
3641  case MaskFormat::Unknown:
3642  return failure();
3643  }
3644  llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
3645  }
3646 };
3647 } // namespace
3648 
3649 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
3650  MLIRContext *context) {
3651  results.add<GatherFolder>(context);
3652 }
3653 
3654 //===----------------------------------------------------------------------===//
3655 // ScatterOp
3656 //===----------------------------------------------------------------------===//
3657 
3658 static LogicalResult verify(ScatterOp op) {
3659  VectorType indVType = op.getIndexVectorType();
3660  VectorType maskVType = op.getMaskVectorType();
3661  VectorType valueVType = op.getVectorType();
3662  MemRefType memType = op.getMemRefType();
3663 
3664  if (valueVType.getElementType() != memType.getElementType())
3665  return op.emitOpError("base and valueToStore element type should match");
3666  if (llvm::size(op.indices()) != memType.getRank())
3667  return op.emitOpError("requires ") << memType.getRank() << " indices";
3668  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
3669  return op.emitOpError("expected valueToStore dim to match indices dim");
3670  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
3671  return op.emitOpError("expected valueToStore dim to match mask dim");
3672  return success();
3673 }
3674 
3675 namespace {
3676 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
3677 public:
3678  using OpRewritePattern<ScatterOp>::OpRewritePattern;
3679  LogicalResult matchAndRewrite(ScatterOp scatter,
3680  PatternRewriter &rewriter) const override {
3681  switch (get1DMaskFormat(scatter.mask())) {
3682  case MaskFormat::AllTrue:
3683  return failure(); // no unmasked equivalent
3684  case MaskFormat::AllFalse:
3685  rewriter.eraseOp(scatter);
3686  return success();
3687  case MaskFormat::Unknown:
3688  return failure();
3689  }
3690  llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
3691  }
3692 };
3693 } // namespace
3694 
3695 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
3696  MLIRContext *context) {
3697  results.add<ScatterFolder>(context);
3698 }
3699 
3700 //===----------------------------------------------------------------------===//
3701 // ExpandLoadOp
3702 //===----------------------------------------------------------------------===//
3703 
3704 static LogicalResult verify(ExpandLoadOp op) {
3705  VectorType maskVType = op.getMaskVectorType();
3706  VectorType passVType = op.getPassThruVectorType();
3707  VectorType resVType = op.getVectorType();
3708  MemRefType memType = op.getMemRefType();
3709 
3710  if (resVType.getElementType() != memType.getElementType())
3711  return op.emitOpError("base and result element type should match");
3712  if (llvm::size(op.indices()) != memType.getRank())
3713  return op.emitOpError("requires ") << memType.getRank() << " indices";
3714  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3715  return op.emitOpError("expected result dim to match mask dim");
3716  if (resVType != passVType)
3717  return op.emitOpError("expected pass_thru of same type as result type");
3718  return success();
3719 }
3720 
3721 namespace {
3722 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
3723 public:
3724  using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
3725  LogicalResult matchAndRewrite(ExpandLoadOp expand,
3726  PatternRewriter &rewriter) const override {
3727  switch (get1DMaskFormat(expand.mask())) {
3728  case MaskFormat::AllTrue:
3729  rewriter.replaceOpWithNewOp<vector::LoadOp>(
3730  expand, expand.getType(), expand.base(), expand.indices());
3731  return success();
3732  case MaskFormat::AllFalse:
3733  rewriter.replaceOp(expand, expand.pass_thru());
3734  return success();
3735  case MaskFormat::Unknown:
3736  return failure();
3737  }
3738  llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
3739  }
3740 };
3741 } // namespace
3742 
3743 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3744  MLIRContext *context) {
3745  results.add<ExpandLoadFolder>(context);
3746 }
3747 
3748 //===----------------------------------------------------------------------===//
3749 // CompressStoreOp
3750 //===----------------------------------------------------------------------===//
3751 
3752 static LogicalResult verify(CompressStoreOp op) {
3753  VectorType maskVType = op.getMaskVectorType();
3754  VectorType valueVType = op.getVectorType();
3755  MemRefType memType = op.getMemRefType();
3756 
3757  if (valueVType.getElementType() != memType.getElementType())
3758  return op.emitOpError("base and valueToStore element type should match");
3759  if (llvm::size(op.indices()) != memType.getRank())
3760  return op.emitOpError("requires ") << memType.getRank() << " indices";
3761  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
3762  return op.emitOpError("expected valueToStore dim to match mask dim");
3763  return success();
3764 }
3765 
3766 namespace {
3767 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
3768 public:
3769  using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
3770  LogicalResult matchAndRewrite(CompressStoreOp compress,
3771  PatternRewriter &rewriter) const override {
3772  switch (get1DMaskFormat(compress.mask())) {
3773  case MaskFormat::AllTrue:
3774  rewriter.replaceOpWithNewOp<vector::StoreOp>(
3775  compress, compress.valueToStore(), compress.base(),
3776  compress.indices());
3777  return success();
3778  case MaskFormat::AllFalse:
3779  rewriter.eraseOp(compress);
3780  return success();
3781  case MaskFormat::Unknown:
3782  return failure();
3783  }
3784  llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
3785  }
3786 };
3787 } // namespace
3788 
3789 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3790  MLIRContext *context) {
3791  results.add<CompressStoreFolder>(context);
3792 }
3793 
3794 //===----------------------------------------------------------------------===//
3795 // ShapeCastOp
3796 //===----------------------------------------------------------------------===//
3797 
3798 /// Returns true if each element of 'a' is equal to the product of a contiguous
3799 /// sequence of the elements of 'b'. Returns false otherwise.
3800 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
3801  unsigned rankA = a.size();
3802  unsigned rankB = b.size();
3803  assert(rankA < rankB);
3804 
3805  unsigned i = 0;
3806  unsigned j = 0;
3807  while (i < rankA && j < rankB) {
3808  int64_t dimA = a[i];
3809  int64_t dimB = 1;
3810  while (dimB < dimA && j < rankB)
3811  dimB *= b[j++];
3812  if (dimA != dimB)
3813  break;
3814  ++i;
3815 
3816  // Handle the case when trailing dimensions are of size 1.
3817  // Include them into the contiguous sequence.
3818  auto isOne = [](int64_t v) { return v == 1; };
3819  if (i < rankA && llvm::all_of(a.slice(i), isOne))
3820  i = rankA;
3821  if (j < rankB && llvm::all_of(b.slice(j), isOne))
3822  j = rankB;
3823  }
3824 
3825  return i == rankA && j == rankB;
3826 }
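
// Worked example (for exposition): isValidShapeCast([4, 5], [2, 2, 5]) is
// true because 4 = 2 * 2 and 5 = 5 are products of contiguous runs of 'b',
// while isValidShapeCast([4, 5], [2, 10]) is false because no contiguous
// prefix of [2, 10] multiplies to exactly 4.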
3827 
3828 static LogicalResult verifyVectorShapeCast(Operation *op,
3829  VectorType sourceVectorType,
3830  VectorType resultVectorType) {
3831  // Check that element type is the same.
3832  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
3833  return op->emitOpError("source/result vectors must have same element type");
3834  auto sourceShape = sourceVectorType.getShape();
3835  auto resultShape = resultVectorType.getShape();
3836 
3837  // Check that product of source dim sizes matches product of result dim sizes.
3838  int64_t sourceDimProduct = std::accumulate(
3839  sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
3840  int64_t resultDimProduct = std::accumulate(
3841  resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
3842  if (sourceDimProduct != resultDimProduct)
3843  return op->emitOpError("source/result number of elements must match");
3844 
3845  // Check the expanding/contracting rank cases.
3846  unsigned sourceRank = sourceVectorType.getRank();
3847  unsigned resultRank = resultVectorType.getRank();
3848  if (sourceRank < resultRank) {
3849  if (!isValidShapeCast(sourceShape, resultShape))
3850  return op->emitOpError("invalid shape cast");
3851  } else if (sourceRank > resultRank) {
3852  if (!isValidShapeCast(resultShape, sourceShape))
3853  return op->emitOpError("invalid shape cast");
3854  }
3855  return success();
3856 }
3857 
3858 static LogicalResult verify(ShapeCastOp op) {
3859  auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>();
3860  auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>();
3861 
3862  // Check if source/result are of vector type.
3863  if (sourceVectorType && resultVectorType)
3864  return verifyVectorShapeCast(op, sourceVectorType, resultVectorType);
3865 
3866  return success();
3867 }
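
// E.g. (illustrative): an expanding cast this verifier accepts, together
// with its contracting inverse:
//
// ```
// %0 = vector.shape_cast %v : vector<4x5xf32> to vector<2x2x5xf32>
// %1 = vector.shape_cast %0 : vector<2x2x5xf32> to vector<4x5xf32>
// ```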
3868 
3869 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
3870  // Nop shape cast.
3871  if (source().getType() == result().getType())
3872  return source();
3873 
3874  // Canceling shape casts.
3875  if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
3876  if (result().getType() == otherOp.source().getType())
3877  return otherOp.source();
3878 
3879  // Only allow valid transitive folding.
3880  VectorType srcType = otherOp.source().getType().cast<VectorType>();
3881  VectorType resultType = getResult().getType().cast<VectorType>();
3882  if (srcType.getRank() < resultType.getRank()) {
3883  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
3884  return {};
3885  } else if (srcType.getRank() > resultType.getRank()) {
3886  if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
3887  return {};
3888  } else {
3889  return {};
3890  }
3891 
3892  setOperand(otherOp.source());
3893  return getResult();
3894  }
3895  return {};
3896 }
3897 
3898 namespace {
3899 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
3900 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
3901 public:
3902  using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
3903 
3904  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
3905  PatternRewriter &rewriter) const override {
3906  auto constantOp = shapeCastOp.source().getDefiningOp<arith::ConstantOp>();
3907  if (!constantOp)
3908  return failure();
3909  // Only handle splat for now.
3910  auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
3911  if (!dense)
3912  return failure();
3913  auto newAttr =
3914  DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(),
3915  dense.getSplatValue<Attribute>());
3916  rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
3917  return success();
3918  }
3919 };
3920 
3921 } // namespace
3922 
3923 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3924  MLIRContext *context) {
3925  // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
3926  results.add<ShapeCastConstantFolder>(context);
3927 }
3928 
3929 //===----------------------------------------------------------------------===//
3930 // VectorBitCastOp
3931 //===----------------------------------------------------------------------===//
3932 
3933 static LogicalResult verify(BitCastOp op) {
3934  auto sourceVectorType = op.getSourceVectorType();
3935  auto resultVectorType = op.getResultVectorType();
3936 
3937  for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
3938  if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
3939  return op.emitOpError("dimension size mismatch at: ") << i;
3940  }
3941 
3942  DataLayout dataLayout = DataLayout::closest(op);
3943  auto sourceElementBits =
3944  dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
3945  auto resultElementBits =
3946  dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
3947 
3948  if (sourceVectorType.getRank() == 0) {
3949  if (sourceElementBits != resultElementBits)
3950  return op.emitOpError("source/result bitwidth of the 0-D vector element "
3951  "types must be equal");
3952  } else if (sourceElementBits * sourceVectorType.getShape().back() !=
3953  resultElementBits * resultVectorType.getShape().back()) {
3954  return op.emitOpError(
3955  "source/result bitwidth of the minor 1-D vectors must be equal");
3956  }
3957 
3958  return success();
3959 }
3960 
3961 OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
3962  // Nop cast.
3963  if (source().getType() == result().getType())
3964  return source();
3965 
3966  // Canceling bitcasts.
3967  if (auto otherOp = source().getDefiningOp<BitCastOp>())
3968  if (result().getType() == otherOp.source().getType())
3969  return otherOp.source();
3970 
3971  Attribute sourceConstant = operands.front();
3972  if (!sourceConstant)
3973  return {};
3974 
3975  Type srcElemType = getSourceVectorType().getElementType();
3976  Type dstElemType = getResultVectorType().getElementType();
3977 
3978  if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
3979  if (floatPack.isSplat()) {
3980  auto splat = floatPack.getSplatValue<FloatAttr>();
3981 
3982  // Casting fp16 into fp32.
3983  if (srcElemType.isF16() && dstElemType.isF32()) {
3984  uint32_t bits = static_cast<uint32_t>(
3985  splat.getValue().bitcastToAPInt().getZExtValue());
3986  // Duplicate the 16-bit pattern.
3987  bits = (bits << 16) | (bits & 0xffff);
3988  APInt intBits(32, bits);
3989  APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
3990  return DenseElementsAttr::get(getResultVectorType(), floatBits);
3991  }
3992  }
3993  }
3994 
3995  return {};
3996 }
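
// Sketch of the splat constant fold above (values assumed): bitcasting a
// splat f16 vector to f32 packs two copies of the 16-bit pattern into each
// 32-bit lane; for 1.0 (0x3C00) the resulting lanes hold 0x3C003C00:
//
// ```
// %cst = arith.constant dense<1.0> : vector<4xf16>
// %0 = vector.bitcast %cst : vector<4xf16> to vector<2xf32>
// ```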
3997 
3998 //===----------------------------------------------------------------------===//
3999 // TypeCastOp
4000 //===----------------------------------------------------------------------===//
4001 
4002 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
4003  auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
4004  SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
4005  memRefType.getShape().end());
4006  if (vectorType)
4007  res.append(vectorType.getShape().begin(), vectorType.getShape().end());
4008  return res;
4009 }
4010 
4011 /// Build the canonical memRefType with a single vector.
4012 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
4013 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
4014  Value source) {
4015  result.addOperands(source);
4016  MemRefType memRefType = source.getType().cast<MemRefType>();
4017  VectorType vectorType =
4018  VectorType::get(extractShape(memRefType),
4019  getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
4020  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
4021  memRefType.getMemorySpace()));
4022 }
4023 
4024 static LogicalResult verify(TypeCastOp op) {
4025  MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
4026  if (!canonicalType.getLayout().isIdentity())
4027  return op.emitOpError(
4028  "expects operand to be a memref with identity layout");
4029  if (!op.getResultMemRefType().getLayout().isIdentity())
4030  return op.emitOpError("expects result to be a memref with identity layout");
4031  if (op.getResultMemRefType().getMemorySpace() !=
4032  op.getMemRefType().getMemorySpace())
4033  return op.emitOpError("expects result in same memory space");
4034 
4035  auto sourceType = op.getMemRefType();
4036  auto resultType = op.getResultMemRefType();
4037  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
4038  getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
4039  return op.emitOpError(
4040  "expects result and operand with same underlying scalar type: ")
4041  << resultType;
4042  if (extractShape(sourceType) != extractShape(resultType))
4043  return op.emitOpError(
4044  "expects concatenated result and operand shapes to be equal: ")
4045  << resultType;
4046  return success();
4047 }
4048 
4049 //===----------------------------------------------------------------------===//
4050 // TransposeOp
4051 //===----------------------------------------------------------------------===//
4052 
4053 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
4054  Value vector, ArrayRef<int64_t> transp) {
4055  VectorType vt = vector.getType().cast<VectorType>();
4056  SmallVector<int64_t, 4> transposedShape(vt.getRank());
4057  for (unsigned i = 0; i < transp.size(); ++i)
4058  transposedShape[i] = vt.getShape()[transp[i]];
4059 
4060  result.addOperands(vector);
4061  result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
4062  result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
4063 }
4064 
4065 // Eliminates transpose operations, which produce values identical to their
4066 // input values. This happens when the dimensions of the input vector remain in
4067 // their original order after the transpose operation.
4068 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
4069  SmallVector<int64_t, 4> transp;
4070  getTransp(transp);
4071 
4072  // Check if the permutation of the dimensions contains sequential values:
4073  // {0, 1, 2, ...}.
4074  for (int64_t i = 0, e = transp.size(); i < e; i++) {
4075  if (transp[i] != i)
4076  return {};
4077  }
4078 
4079  return vector();
4080 }
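
// For example (illustrative): an identity permutation folds away, yielding
// the transpose's operand directly:
//
// ```
// %0 = vector.transpose %v, [0, 1] : vector<4x8xf32> to vector<4x8xf32>
// // folds to %v.
// ```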
4081 
4082 static LogicalResult verify(vector::TransposeOp op) {
4083  VectorType vectorType = op.getVectorType();
4084  VectorType resultType = op.getResultType();
4085  int64_t rank = resultType.getRank();
4086  if (vectorType.getRank() != rank)
4087  return op.emitOpError("vector result rank mismatch: ") << rank;
4088  // Verify transposition array.
4089  auto transpAttr = op.transp().getValue();
4090  int64_t size = transpAttr.size();
4091  if (rank != size)
4092  return op.emitOpError("transposition length mismatch: ") << size;
4093  SmallVector<bool, 8> seen(rank, false);
4094  for (const auto &ta : llvm::enumerate(transpAttr)) {
4095  int64_t i = ta.value().cast<IntegerAttr>().getInt();
4096  if (i < 0 || i >= rank)
4097  return op.emitOpError("transposition index out of range: ") << i;
4098  if (seen[i])
4099  return op.emitOpError("duplicate position index: ") << i;
4100  seen[i] = true;
4101  if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
4102  return op.emitOpError("dimension size mismatch at: ") << i;
4103  }
4104  return success();
4105 }
4106 
4107 namespace {
4108 
4109 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
4110 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
4111 public:
4112  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
4113 
4114  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
4115  PatternRewriter &rewriter) const override {
4116  // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
4117  auto getPermutation = [](vector::TransposeOp transpose) {
4118  SmallVector<int64_t, 4> permutation;
4119  transpose.getTransp(permutation);
4120  return permutation;
4121  };
4122 
4123  // Composes two permutations: result[i] = permutation1[permutation2[i]].
4124  auto composePermutations = [](ArrayRef<int64_t> permutation1,
4125  ArrayRef<int64_t> permutation2) {
4126  SmallVector<int64_t, 4> result;
4127  for (auto index : permutation2)
4128  result.push_back(permutation1[index]);
4129  return result;
4130  };
4131 
4132  // Return if the input of 'transposeOp' is not defined by another transpose.
4133  vector::TransposeOp parentTransposeOp =
4134  transposeOp.vector().getDefiningOp<vector::TransposeOp>();
4135  if (!parentTransposeOp)
4136  return failure();
4137 
4138  SmallVector<int64_t, 4> permutation = composePermutations(
4139  getPermutation(parentTransposeOp), getPermutation(transposeOp));
4140  // Replace 'transposeOp' with a new transpose operation.
4141  rewriter.replaceOpWithNewOp<vector::TransposeOp>(
4142  transposeOp, transposeOp.getResult().getType(),
4143  parentTransposeOp.vector(),
4144  vector::getVectorSubscriptAttr(rewriter, permutation));
4145  return success();
4146  }
4147 };
4148 
4149 } // namespace
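
// Worked composition example (illustration): if the parent transpose uses
// permutation1 = [1, 0, 2] and the child uses permutation2 = [2, 0, 1], then
// result[i] = permutation1[permutation2[i]] gives [2, 1, 0], so the two
// transposes fuse into a single vector.transpose with that permutation.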
4150 
4151 void vector::TransposeOp::getCanonicalizationPatterns(
4152  RewritePatternSet &results, MLIRContext *context) {
4153  results.add<TransposeFolder>(context);
4154 }
4155 
4156 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
4157  populateFromInt64AttrArray(transp(), results);
4158 }
4159 
4160 //===----------------------------------------------------------------------===//
4161 // ConstantMaskOp
4162 //===----------------------------------------------------------------------===//
4163 
4164 static LogicalResult verify(ConstantMaskOp &op) {
4165  auto resultType = op.getResult().getType().cast<VectorType>();
4166  // Check the corner case of 0-D vectors first.
4167  if (resultType.getRank() == 0) {
4168  if (op.mask_dim_sizes().size() != 1)
4169  return op->emitError("array attr must have length 1 for 0-D vectors");
4170  auto dim = op.mask_dim_sizes()[0].cast<IntegerAttr>().getInt();
4171  if (dim != 0 && dim != 1)
4172  return op->emitError(
4173  "mask dim size must be either 0 or 1 for 0-D vectors");
4174  return success();
4175  }
4176 
4177  // Verify that array attr size matches the rank of the vector result.
4178  if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
4179  return op.emitOpError(
4180  "must specify array attr of size equal vector result rank");
4181  // Verify that each array attr element is in bounds of corresponding vector
4182  // result dimension size.
4183  auto resultShape = resultType.getShape();
4184  SmallVector<int64_t, 4> maskDimSizes;
4185  for (const auto &it : llvm::enumerate(op.mask_dim_sizes())) {
4186  int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
4187  if (attrValue < 0 || attrValue > resultShape[it.index()])
4188  return op.emitOpError(
4189  "array attr of size out of bounds of vector result dimension size");
4190  maskDimSizes.push_back(attrValue);
4191  }
4192  // Verify that if one mask dim size is zero, they all should be zero (because
4193  // the mask region is a conjunction of each mask dimension interval).
4194  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
4195  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
4196  if (anyZeros && !allZeros)
4197  return op.emitOpError("expected all mask dim sizes to be zeros, "
4198  "as a result of conjunction with zero mask dim");
4199  return success();
4200 }
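
// E.g. (illustrative): a constant mask this verifier accepts; lanes inside
// the [3, 2] prefix region are true, all others false:
//
// ```
// %m = vector.constant_mask [3, 2] : vector<4x3xi1>
// ```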
4201 
4202 //===----------------------------------------------------------------------===//
4203 // CreateMaskOp
4204 //===----------------------------------------------------------------------===//
4205 
4206 static LogicalResult verify(CreateMaskOp op) {
4207  auto vectorType = op.getResult().getType().cast<VectorType>();
4208  // Verify that an operand was specified for each result vector dimension.
4209  if (vectorType.getRank() == 0) {
4210  if (op->getNumOperands() != 1)
4211  return op.emitOpError(
4212  "must specify exactly one operand for 0-D create_mask");
4213  } else if (op.getNumOperands() !=
4214  op.getResult().getType().cast<VectorType>().getRank()) {
4215  return op.emitOpError(
4216  "must specify an operand for each result vector dimension");
4217  }
4218  return success();
4219 }
4220 
4221 namespace {
4222 
4223 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
4224 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
4225 public:
4226  using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
4227 
4228  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
4229  PatternRewriter &rewriter) const override {
4230  // Return if any of 'createMaskOp' operands are not defined by a constant.
4231  auto isNotDefByConstant = [](Value operand) {
4232  return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
4233  };
4234  if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
4235  return failure();
4236  // Gather constant mask dimension sizes.
4237  SmallVector<int64_t, 4> maskDimSizes;
4238  for (auto it : llvm::zip(createMaskOp.operands(),
4239  createMaskOp.getType().getShape())) {
4240  auto *defOp = std::get<0>(it).getDefiningOp();
4241  int64_t maxDimSize = std::get<1>(it);
4242  int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
4243  dimSize = std::min(dimSize, maxDimSize);
4244  // If any of the dim sizes is zero, set all dims to zero.
4245  if (dimSize <= 0) {
4246  maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
4247  break;
4248  }
4249  maskDimSizes.push_back(dimSize);
4250  }
4251  // Replace 'createMaskOp' with ConstantMaskOp.
4252  rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4253  createMaskOp, createMaskOp.getResult().getType(),
4254  vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
4255  return success();
4256  }
4257 };
4258 
4259 } // namespace
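
// The effect of CreateMaskFolder, sketched in textual IR (an assumed
// example, not excerpted from the tests):
//
//   %c2 = arith.constant 2 : index
//   %m  = vector.create_mask %c2 : vector<4xi1>
//
// canonicalizes to:
//
//   %m = vector.constant_mask [2] : vector<4xi1>
//
// Operands larger than the corresponding dimension are clamped to it, and a
// non-positive operand collapses every mask dim size to zero.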
4260 
4261 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
4262  MLIRContext *context) {
4263  results.add<CreateMaskFolder>(context);
4264 }
4265 
4266 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
4267  RewritePatternSet &patterns) {
4268  patterns
4269  .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
4270  ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
4271  StridedSliceConstantMaskFolder, TransposeFolder>(
4272  patterns.getContext());
4273 }
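
// Illustrative sketch, not part of the original file: one way to drive the
// patterns collected above, assuming the greedy driver from
// "mlir/Transforms/GreedyPatternRewriteDriver.h" is included.
static LogicalResult runVectorCanonicalizations(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorToVectorCanonicalizationPatterns(patterns);
  // Apply the patterns repeatedly until a fixed point is reached.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}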
4274 
4275 #define GET_OP_CLASSES
4276 #include "mlir/Dialect/Vector/VectorOps.cpp.inc"