//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;
/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}
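
// For example (illustrative, not part of the upstream file): for `%v` of type
// `tensor<4x?xf32>`, dim 0 folds to the index attribute 4, while dim 1 is
// dynamic and materializes an op:
//   %c1 = arith.constant 1 : index
//   %d1 = tensor.dim %v, %c1 : tensor<4x?xf32>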

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}
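
// For example (illustrative): with offsets [0, 0], sizes [4, 4], and strides
// [1, 1], a ranked tensor source yields
//   %s = tensor.extract_slice %t[0, 0] [4, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<4x4xf32>
// while a memref source yields the analogous memref.subview; any other type
// returns nullptr.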

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc,
                                       Value source, int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(
    ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
    function_ref<InFlightDiagnostic()>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   function_ref<InFlightDiagnostic()> emitError,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs, emitError);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), /*emitError=*/{},
                         regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}
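
// For reference (illustrative sketch): the default indexing maps for a plain
// matmul are the usual (m, n, k) maps:
//   affine_map<(d0, d1, d2) -> (d0, d2)>  // LHS
//   affine_map<(d0, d1, d2) -> (d2, d1)>  // RHS
//   affine_map<(d0, d1, d2) -> (d0, d1)>  // output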

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
                                     std::optional<TypeRange> resultTensorTypes,
                                     ValueRange inputs, ValueRange outputs,
                                     ArrayRef<NamedAttribute> attributes,
                                     RegionBuilderFn regionBuilder,
                                     ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchReduceMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}
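
// For example (illustrative), this parses the common textual segment
//   ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
//   outs(%c : tensor<4x16xf32>)
// resolving two input operands and one output operand, and records an
// operandSegmentSizes of [2, 1].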

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder, SMLoc loc) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  ParseResult result = success();
  fillStructuredOpRegion(
      opBuilder, region, inputTypes, outputTypes, attrs,
      [&]() {
        result = failure();
        return parser.emitError(loc);
      },
      regionBuilder);
  return result;
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  SMLoc loc = parser.getCurrentLocation();
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder, loc))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code. They are helpers to build the unary, binary, and type conversion
// functions defined by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc
// for the code that uses this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr. This
// can be extended later if it becomes possible to fail construction of the
// region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the mean-time, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
                     function_ref<InFlightDiagnostic()> emitError = {}) {
    if (!isFloatingPoint(arg)) {
      if (emitError) {
        emitError() << "unsupported non numeric type";
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    if (emitError) {
      emitError() << "unsupported unary function";
      return nullptr;
    }
    llvm_unreachable("unsupported unary function");
  }
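
  // For example (illustrative): buildUnaryFn(UnaryFn::exp, %x) with %x : f32
  // emits `math.exp %x : f32`, and UnaryFn::square emits
  // `arith.mulf %x, %x : f32`.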

  // Build the binary functions defined by OpDSL.
  // If emitError is provided, an error will be emitted if the operation is not
  // supported and a nullptr will be returned; otherwise an assertion will be
  // raised.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
                      function_ref<InFlightDiagnostic()> emitError = {}) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger) {
      if (emitError) {
        emitError()
            << "Cannot build binary Linalg operation: expects allComplex, "
               "allFloatingPoint, or allInteger, got "
            << arg0.getType() << " and " << arg1.getType();
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: sub with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: sub with bools");
      }
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: div with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: div with bools");
      }
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned div not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      }
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    if (emitError) {
      emitError() << "unsupported binary function";
      return nullptr;
    }
    llvm_unreachable("unsupported binary function");
  }
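
  // For example (illustrative): BinaryFn::add emits `arith.addf` for f32
  // operands, `arith.addi` for i32, and `arith.ori` for i1; BinaryFn::mul on
  // i1 lowers to `arith.andi`.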

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
                       function_ref<InFlightDiagnostic()> emitError = {}) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    if (emitError) {
      emitError() << "unsupported ternary function";
      return nullptr;
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
                    function_ref<InFlightDiagnostic()> emitError = {}) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    if (emitError) {
      emitError() << "unsupported type conversion function";
      return nullptr;
    }
    llvm_unreachable("unsupported type conversion function");
  }
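
  // For example (illustrative sketch of convertScalarToDtype's behavior):
  // casting i8 -> i32 emits `arith.extsi` under cast_signed and `arith.extui`
  // under cast_unsigned, while f32 -> i32 under cast_signed emits
  // `arith.fptosi`.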

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    if (isa<UnknownLoc>(loc)) {
      if (operand.getDefiningOp())
        loc = operand.getDefiningOp()->getLoc();
      else if (operand.getParentBlock() &&
               operand.getParentBlock()->getParentOp())
        loc = operand.getParentBlock()->getParentOp()->getLoc();
    }
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};
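
// For example (illustrative): the self copy
//   linalg.copy ins(%buf : memref<16xf32>) outs(%buf : memref<16xf32>)
// is erased outright; with tensor semantics the copy's result is replaced by
// its input value instead.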

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold a linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};
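
// For example (illustrative):
//   %f = linalg.fill ins(%cst : f32) outs(%init : tensor<16xf32>)
//       -> tensor<16xf32>
//   %r = tensor.expand_shape %f [[0, 1]] output_shape [4, 4]
//       : tensor<16xf32> into tensor<4x4xf32>
// becomes a reshape of %init followed by a single fill producing the 4x4
// result directly.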

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
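
// For example (illustrative): padding a filled tensor with the same constant
//   %f = linalg.fill ins(%cst : f32) outs(%t : tensor<4xf32>) -> tensor<4xf32>
//   %p = tensor.pad %f low[2] high[2] { ... tensor.yield %cst : f32 }
//       : tensor<4xf32> to tensor<8xf32>
// is rewritten into a single linalg.fill over a tensor.empty of the padded
// shape.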

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has a dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. They should be the old offsets
    // plus the low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};
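
// For example (illustrative): inserting tensor.pad(%src) into a linalg.fill
// destination that uses the same constant folds to inserting %src directly,
// with each offset shifted by the corresponding low pad
// (newOffset = oldOffset + lowPad) and the sizes taken from %src's shape.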

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
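
// For example (illustrative):
//   %f = linalg.fill ins(%cst : f32) outs(%t : tensor<4x8xf32>)
//       -> tensor<4x8xf32>
//   %e = tensor.extract %f[%i, %j] : tensor<4x8xf32>
// folds all uses of %e directly to %cst.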

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have a padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<linalg::PackOp>(context) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
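
// For example (illustrative): concatenating two fills of the same constant
//   %0 = linalg.fill ins(%cst : f32) outs(%a : tensor<4x8xf32>)
//       -> tensor<4x8xf32>
//   %1 = linalg.fill ins(%cst : f32) outs(%b : tensor<4x8xf32>)
//       -> tensor<4x8xf32>
//   %2 = tensor.concat dim(0) %0, %1
//       : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<8x8xf32>
// becomes a concat of %a and %b followed by a single fill of the result.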

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits into a dictAttr. The name is unimportant as
  // we will overwrite result.attributes. The core linalg traits must contain
  // the information necessary to pass the verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
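
// For example (illustrative), the generic form parsed here looks like:
//   #map = affine_map<(d0) -> (d0)>
//   %res = linalg.generic {indexing_maps = [#map, #map],
//                          iterator_types = ["parallel"]}
//       ins(%in : tensor<8xf32>) outs(%init : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<8xf32>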

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove linalg operations that are just copying the values from inputs to
/// results. In the memref case, the operation must be copying to and from the
/// same value. Requirements are:
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
        return rewriter.notifyMatchFailure(
            linalgOp, "expected single input and output to be the same value");
      }

      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
        return rewriter.notifyMatchFailure(linalgOp,
                                           "cannot fold fill-like op");
      }

      rewriter.eraseOp(linalgOp);
      return success();
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "mixed semantics is not supported yet");
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
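
// For example (illustrative): an identity generic such as
//   %r = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel"]}
//       ins(%x : tensor<8xf32>) outs(%y : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<8xf32>
// is removed, with all uses of %r replaced by %x.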

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}
1503 
1504 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
1505  const OperationName &payloadOpName,
1506  const NamedAttrList &payloadOpAttrs,
1507  ArrayRef<Value> operands,
1508  bool initFirst = false) {
1509  OpBuilder b(parser.getContext());
1510  Region *body = result.addRegion();
1511  Block &block = body->emplaceBlock();
1512  b.setInsertionPointToStart(&block);
1513  for (auto &operand : operands) {
1514  block.addArgument(
1515  llvm::cast<ShapedType>(operand.getType()).getElementType(),
1516  b.getUnknownLoc());
1517  }
1518  SmallVector<Value> payloadOpOperands;
1519  // If the initFirst flag is enabled, init is placed at the first position
1520  // of the payload operands.
1521  if (initFirst) {
1522  payloadOpOperands.push_back(block.getArguments().back());
1523  for (const auto &arg : block.getArguments().drop_back())
1524  payloadOpOperands.push_back(arg);
1525  } else {
1526  payloadOpOperands = {block.getArguments().begin(),
1527  block.getArguments().end()};
1528  }
1529 
1530  Operation *payloadOp = b.create(
1531  result.location, b.getStringAttr(payloadOpName.getStringRef()),
1532  payloadOpOperands,
1533  TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1534  .getElementType()},
1535  payloadOpAttrs);
1536  b.create<YieldOp>(result.location, payloadOp->getResults());
1537 }
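// For illustration, with a hypothetical `linalg.map { arith.addf }` on two
// inputs, the body generated above is:
//
//   ^bb0(%lhs: f32, %rhs: f32):
//     %0 = arith.addf %lhs, %rhs : f32
//     linalg.yield %0 : f32
//
// since MapOp passes only the input operands here (initFirst == false).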
1538 
1539 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1540  std::optional<OperationName> payloadOpName;
1541  NamedAttrList payloadOpAttrs;
1542  if (succeeded(parser.parseOptionalLBrace())) {
1543  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1544  if (failed(operationName))
1545  return failure();
1546  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1547  return failure();
1548  payloadOpName = operationName.value();
1549  if (parser.parseRBrace())
1550  return failure();
1551  }
1552 
1553  if (parseDstStyleOp(parser, result))
1554  return failure();
1555 
1556  if (payloadOpName.has_value()) {
1557  if (!result.operands.empty())
1558  addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1559  payloadOpAttrs,
1560  ArrayRef(result.operands).drop_back());
1561  else
1562  result.addRegion();
1563  } else {
1564  SmallVector<OpAsmParser::Argument> regionArgs;
1565  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1566  /*allowType=*/true, /*allowAttrs=*/true)) {
1567  return failure();
1568  }
1569  Region *body = result.addRegion();
1570  if (parser.parseRegion(*body, regionArgs))
1571  return failure();
1572  }
1573  return success();
1574 }
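// For illustration, both forms accepted above (values and types are
// hypothetical):
//
//   // Short form with an inlined payload op:
//   %mapped = linalg.map { arith.addf }
//               ins(%a, %b : tensor<4xf32>, tensor<4xf32>)
//               outs(%init : tensor<4xf32>)
//
//   // Long form with an explicit region:
//   %mapped = linalg.map
//               ins(%a, %b : tensor<4xf32>, tensor<4xf32>)
//               outs(%init : tensor<4xf32>)
//               (%lhs: f32, %rhs: f32) {
//                 %0 = arith.addf %lhs, %rhs : f32
//                 linalg.yield %0 : f32
//               }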
1575 
1576 // Retrieve the operation from the body, if it is the only one (besides the
1577 // yield) and if it takes the same number of arguments as the body does.
1578 // If the initFirst flag is enabled, we check that init occupies the first
1579 // position among the operands of the payload.
1580 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1581  if (body->getOperations().size() != 2)
1582  return nullptr;
1583  Operation &payload = body->getOperations().front();
1584  assert(isa<YieldOp>(body->getOperations().back()));
1585 
1586  if (payload.getNumOperands() == 0 ||
1587  payload.getNumOperands() != body->getNumArguments())
1588  return nullptr;
1589  if (initFirst) {
1590  // Check init: the payload's last operand must wrap around to bbArg 0.
1591  if (payload.getOperands().back() != body->getArgument(0))
1592  return nullptr;
1593  // Check the rest, shifted by one since init leads the payload operands.
1594  for (const auto &[operand, bbArg] :
1595  llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1596  if (bbArg != operand)
1597  return nullptr;
1598  }
1599  } else {
1600  for (const auto &[operand, bbArg] :
1601  llvm::zip(payload.getOperands(), body->getArguments())) {
1602  if (bbArg != operand)
1603  return nullptr;
1604  }
1605  }
1606  return &payload;
1607 }
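// For illustration, a reduce body that matches the short form with
// initFirst == true (names are hypothetical):
//
//   ^bb0(%in: f32, %init: f32):
//     %0 = arith.addf %init, %in : f32
//     linalg.yield %0 : f32
//
// whereas `arith.addf %in, %init` would not match and keeps the long form.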
1608 
1609 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1610  SmallVector<StringRef> elidedAttrs;
1611  std::string attrToElide;
1612  p << " { " << payloadOp->getName().getStringRef();
1613  for (const auto &attr : payloadOp->getAttrs()) {
1614  auto fastAttr =
1615  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1616  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1617  attrToElide = attr.getName().str();
1618  elidedAttrs.push_back(attrToElide);
1619  break;
1620  }
1621  }
1622  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1623  p << " }";
1624 }
1625 
1626 void MapOp::print(OpAsmPrinter &p) {
1627  Block *mapper = getBody();
1628  Operation *payloadOp = findPayloadOp(mapper);
1629  if (payloadOp) {
1630  printShortForm(p, payloadOp);
1631  }
1632 
1633  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1634  p.printOptionalAttrDict((*this)->getAttrs());
1635 
1636  if (!payloadOp) {
1637  // Print region if the payload op was not detected.
1638  p.increaseIndent();
1639  p.printNewline();
1640  p << "(";
1641  llvm::interleaveComma(mapper->getArguments(), p,
1642  [&](auto arg) { p.printRegionArgument(arg); });
1643  p << ") ";
1644 
1645  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1646  p.decreaseIndent();
1647  }
1648 }
1649 
1650 LogicalResult MapOp::verify() {
1651  auto *bodyBlock = getBody();
1652  auto blockArgs = bodyBlock->getArguments();
1653 
1654  // Checks if the number of `inputs` matches the arity of the `mapper` region.
1655  if (getInputs().size() != blockArgs.size())
1656  return emitOpError() << "expects number of operands to match the arity of "
1657  "mapper, but got: "
1658  << getInputs().size() << " and " << blockArgs.size();
1659 
1660  // The parameters of the mapper should all match the element types of the inputs.
1661  for (const auto &[bbArgType, inputArg] :
1662  llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1663  auto inputElemType =
1664  llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1665  if (bbArgType != inputElemType) {
1666  return emitOpError() << "expected element type of input " << inputElemType
1667  << " to match bbArg type " << bbArgType;
1668  }
1669  }
1670 
1671  // The shape of each input must match the shape of the output.
1672  auto outputShape = getInit().getType().getShape();
1673  for (Type inputArgType : TypeRange{getInputs()}) {
1674  auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1675  if (inputElemShape != outputShape) {
1676  return emitOpError() << "expected shape of input (" << inputElemShape
1677  << ") to match shape of output (" << outputShape
1678  << ")";
1679  }
1680  }
1681 
1682  return success();
1683 }
1684 
1685 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1686  int64_t rank = getInit().getType().getRank();
1687  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1688 }
1689 
1690 ArrayAttr MapOp::getIndexingMaps() {
1691  Builder builder(getContext());
1692  int64_t rank = getInit().getType().getRank();
1693  int64_t numIndexingMaps = getOperands().size();
1694  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1695  numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1696 }
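// E.g., a hypothetical rank-2 map with two inputs and one init yields three
// identity maps: (d0, d1) -> (d0, d1) for every operand.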
1697 
1698 void MapOp::getEffects(
1699  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1700  &effects) {
1701  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1702 }
1703 
1704 Speculation::Speculatability MapOp::getSpeculatability() {
1705  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1706 }
1707 
1708 //===----------------------------------------------------------------------===//
1709 // ReduceOp
1710 //===----------------------------------------------------------------------===//
1711 
1712 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1713  OpAsmSetValueNameFn setNameFn) {
1714  for (Value v : getRegionInputArgs())
1715  setNameFn(v, "in");
1716  for (Value v : getRegionOutputArgs())
1717  setNameFn(v, "init");
1718 }
1719 
1720 void ReduceOp::getAsmResultNames(
1721  function_ref<void(Value, StringRef)> setNameFn) {
1722  if (!getResults().empty())
1723  setNameFn(getResults().front(), "reduced");
1724 }
1725 
1726 void ReduceOp::build(
1727  OpBuilder &builder, OperationState &result, ValueRange inputs,
1728  ValueRange inits, ArrayRef<int64_t> dimensions,
1729  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1730  ArrayRef<NamedAttribute> attributes) {
1731  build(builder, result, TypeRange{}, inputs, inits, dimensions);
1732  result.addAttributes(attributes);
1733 
1734  // Add output types for `RankedTensorType` output arguments.
1735  for (Value init : inits) {
1736  Type initType = init.getType();
1737  if (llvm::isa<RankedTensorType>(initType))
1738  result.addTypes(initType);
1739  }
1740 
1741  if (bodyBuild)
1742  buildGenericRegion(builder, result.location, *result.regions.front(),
1743  inputs, inits, bodyBuild);
1744 }
1745 
1746 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1747  int64_t inputRank =
1748  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1749  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1750  utils::IteratorType::parallel);
1751  for (int64_t reductionDim : getDimensions())
1752  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1753  return iteratorTypes;
1754 }
1755 
1756 ArrayAttr ReduceOp::getIndexingMaps() {
1757  int64_t inputRank =
1758  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1759  SmallVector<AffineMap> affineMaps(
1760  getNumDpsInputs(),
1761  AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1762  AffineMap resultMap =
1763  AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1764  .dropResults(getDimensions());
1765  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1766  affineMaps.push_back(resultMap);
1767  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1768 }
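// E.g., a hypothetical rank-3 reduction with dimensions = [1] yields
//   input map: (d0, d1, d2) -> (d0, d1, d2)
//   init map:  (d0, d1, d2) -> (d0, d2)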
1769 
1770 void ReduceOp::getEffects(
1771  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1772  &effects) {
1773  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1774 }
1775 
1776 Speculation::Speculatability ReduceOp::getSpeculatability() {
1777  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1778 }
1779 
1780 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1781  NamedAttrList &attributes,
1782  StringRef attributeName) {
1783  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1784  return failure();
1785 
1786  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1787  return success();
1788 }
1789 
1790 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1791  std::optional<OperationName> payloadOpName;
1792  NamedAttrList payloadOpAttrs;
1793  if (succeeded(parser.parseOptionalLBrace())) {
1794  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1795  if (failed(operationName))
1796  return failure();
1797  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1798  return failure();
1799  payloadOpName = operationName.value();
1800  if (parser.parseRBrace())
1801  return failure();
1802  }
1803 
1804  if (parseDstStyleOp(
1805  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1806  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1807  }))
1808  return failure();
1809 
1810  if (payloadOpName.has_value()) {
1811  addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1812  ArrayRef(result.operands), /*initFirst=*/true);
1813  } else {
1814  SmallVector<OpAsmParser::Argument> regionArgs;
1815  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1816  /*allowType=*/true, /*allowAttrs=*/true)) {
1817  return failure();
1818  }
1819 
1820  Region *body = result.addRegion();
1821  if (parser.parseRegion(*body, regionArgs))
1822  return failure();
1823  }
1824 
1825  return success();
1826 }
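// For illustration, the short form accepted above (values and shapes are
// hypothetical):
//
//   %reduced = linalg.reduce { arith.addf }
//                ins(%in : tensor<16x32xf32>)
//                outs(%init : tensor<16xf32>)
//                dimensions = [1]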
1827 
1828 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1829  ArrayRef<int64_t> attributeValue) {
1830  p << ' ' << attributeName << " = [" << attributeValue << "] ";
1831 }
1832 
1833 void ReduceOp::print(OpAsmPrinter &p) {
1834  Block *mapper = getBody();
1835  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1836  if (payloadOp) {
1837  printShortForm(p, payloadOp);
1838  }
1839 
1840  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1841  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1842  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1843  if (!payloadOp) {
1844  // Print region if the payload op was not detected.
1845  p.increaseIndent();
1846  p.printNewline();
1847  p << "(";
1848  llvm::interleaveComma(mapper->getArguments(), p,
1849  [&](auto arg) { p.printRegionArgument(arg); });
1850  p << ") ";
1851 
1852  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1853  p.decreaseIndent();
1854  }
1855 }
1856 
1857 LogicalResult ReduceOp::verify() {
1858  ArrayRef<int64_t> dimensionsRef = getDimensions();
1859 
1860  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1861  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1862  llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1863  return emitOpError() << "expects all inputs to have the same shapes. "
1864  "Shape at input-index "
1865  << i
1866  << " is not equal to the shape at input-index 0.";
1867  }
1868  }
1869  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1870  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1871  llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1872  return emitOpError() << "expects all outputs to have the same shapes. "
1873  "Shape at output-index "
1874  << i
1875  << " is not equal to the shape at output-index 0.";
1876  }
1877  }
1878  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1879  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1880 
1881  DenseSet<int64_t> dimensionsToReduce;
1882  for (int64_t dimension : dimensionsRef) {
1883  if (dimension < 0 || dimension >= inputType.getRank()) {
1884  return emitOpError()
1885  << "dimensions for reduction should be in the range [0, "
1886  << inputType.getRank() - 1 << "].";
1887  }
1888  dimensionsToReduce.insert(dimension);
1889  }
1890 
1891  auto inputDims = inputType.getShape();
1892  auto initDims = initType.getShape();
1893 
1894  // Input dimensions that will be left after the reduction.
1895  SmallVector<int64_t> reducedInputDims;
1896  for (const auto &en : llvm::enumerate(inputDims)) {
1897  if (!dimensionsToReduce.count(en.index()))
1898  reducedInputDims.push_back(en.value());
1899  }
1900 
1901  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1902  return emitOpError() << "number of dimensions after reduction "
1903  << reducedInputDims.size()
1904  << " doesn't match the init rank "
1905  << initType.getRank();
1906  }
1907 
1908  if (reducedInputDims != initDims)
1909  return emitOpError() << "init dimensions [" << initDims
1910  << "] doesn't match input dimensions after reduction ["
1911  << reducedInputDims << "]";
1912 
1913  Block *block = getBody();
1914  if (block->getNumArguments() != this->getNumOperands())
1915  return emitOpError()
1916  << "mismatching number of operands and block arguments";
1917 
1918  // Check that the first block arguments match the element type of the inputs.
1919  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1920  Type inputElementType =
1921  llvm::cast<ShapedType>(input.getType()).getElementType();
1922  if (inputElementType != bbArg.getType())
1923  return emitOpError()
1924  << "input element type " << inputElementType
1925  << " does not match corresponding block argument type "
1926  << bbArg.getType();
1927  }
1928 
1929  // Check that the last block arguments match the element type of the outputs.
1930  for (auto [output, bbArg] : llvm::zip(
1931  getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1932  auto outputElementType =
1933  llvm::cast<ShapedType>(output.getType()).getElementType();
1934  if (outputElementType != bbArg.getType())
1935  return emitOpError()
1936  << "output element type " << outputElementType
1937  << " does not match corresponding block argument type "
1938  << bbArg.getType();
1939  }
1940  return success();
1941 }
1942 
1943 //===----------------------------------------------------------------------===//
1944 // TransposeOp
1945 //===----------------------------------------------------------------------===//
1946 
1947 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1948  Region &region, ValueRange inputs,
1949  ValueRange outputs) {
1950  buildGenericRegion(builder, loc, region, inputs, outputs,
1951  [](OpBuilder &b, Location loc, ValueRange args) {
1952  if (!args.empty())
1953  b.create<linalg::YieldOp>(loc, args[0]);
1954  });
1955 }
1956 
1957 void TransposeOp::build(::mlir::OpBuilder &builder,
1958  ::mlir::OperationState &result, Value input, Value init,
1959  DenseI64ArrayAttr permutation,
1960  ArrayRef<NamedAttribute> attributes) {
1961  result.addOperands(input);
1962  result.addOperands(init);
1963  result.addAttribute(getPermutationAttrName(result.name), permutation);
1964  result.addAttributes(attributes);
1965 
1966  // Add output types for `RankedTensorType` output arguments.
1967  Type initType = init.getType();
1968  if (llvm::isa<RankedTensorType>(initType))
1969  result.addTypes(initType);
1970 
1971  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1972  init);
1973 }
1974 
1975 void TransposeOp::build(::mlir::OpBuilder &builder,
1976  ::mlir::OperationState &result, Value input, Value init,
1977  ArrayRef<int64_t> permutation,
1978  ArrayRef<NamedAttribute> attributes) {
1979  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1980  attributes);
1981 }
1982 
1983 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1984  if (failed(parseDstStyleOp(
1985  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1986  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1987  })))
1988  return failure();
1989 
1990  OpBuilder builder(parser.getContext());
1991  buildIdentityRegion(builder, result.location, *result.addRegion(),
1992  /*inputs=*/result.operands,
1993  /*outputs=*/{});
1994  return success();
1995 }
1996 
1997 void TransposeOp::getAsmResultNames(
1998  function_ref<void(Value, StringRef)> setNameFn) {
1999  if (!getResults().empty())
2000  setNameFn(getResults().front(), "transposed");
2001 }
2002 
2003 void TransposeOp::print(OpAsmPrinter &p) {
2004  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2005  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
2006  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2007 }
2008 
2009 LogicalResult TransposeOp::verify() {
2010  ArrayRef<int64_t> permutationRef = getPermutation();
2011 
2012  if (!isPermutationVector(permutationRef))
2013  return emitOpError("permutation is not valid");
2014 
2015  auto inputType = getInput().getType();
2016  auto initType = getInit().getType();
2017 
2018  int64_t rank = inputType.getRank();
2019 
2020  if (rank != initType.getRank())
2021  return emitOpError() << "input rank " << rank
2022  << " does not match init rank " << initType.getRank();
2023 
2024  if (rank != static_cast<int64_t>(permutationRef.size()))
2025  return emitOpError() << "size of permutation " << permutationRef.size()
2026  << " does not match the argument rank " << rank;
2027 
2028  auto inputDims = inputType.getShape();
2029  auto initDims = initType.getShape();
2030 
2031  for (int64_t i = 0; i < rank; ++i) {
2032  int64_t inputDim = inputDims[permutationRef[i]];
2033  int64_t initDim = initDims[i];
2034 
2035  if (inputDim != initDim) {
2036  return emitOpError() << "dim(result, " << i << ") = " << initDim
2037  << " doesn't match dim(input, permutation[" << i
2038  << "]) = " << inputDim;
2039  }
2040  }
2041 
2042  return success();
2043 }
2044 
2045 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2046  int64_t rank = getInit().getType().getRank();
2047  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2048 }
2049 
2050 ArrayAttr TransposeOp::getIndexingMaps() {
2051  Builder builder(getContext());
2052  int64_t rank = getInit().getType().getRank();
2053  return builder.getAffineMapArrayAttr(
2054  {inversePermutation(AffineMap::getPermutationMap(
2055  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
2056  builder.getMultiDimIdentityMap(rank)});
2057 }
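// E.g., for a hypothetical permutation = [1, 2, 0] this yields
//   input map: (d0, d1, d2) -> (d2, d0, d1)   (inverse of the permutation)
//   init map:  (d0, d1, d2) -> (d0, d1, d2)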
2058 
2059 void TransposeOp::getEffects(
2060  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2061  &effects) {
2062  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2063 }
2064 
2065 Speculation::Speculatability TransposeOp::getSpeculatability() {
2066  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2067 }
2068 
2069 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2070  SmallVectorImpl<OpFoldResult> &result) {
2071  // Only tensor types are supported.
2072  if (!isa<TensorType>(getInput().getType()))
2073  return failure();
2074 
2075  // A transpose with an empty permutation (rank 0) folds to its input.
2076  if (getPermutation().size() == 0) {
2077  result.push_back(getInput());
2078  return success();
2079  }
2080  // Identity permutation.
2081  if (isIdentityPermutation(getPermutation())) {
2082  result.push_back(getInput());
2083  return success();
2084  }
2085 
2086  return failure();
2087 }
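// For illustration (shapes are hypothetical), an identity transpose
//   %t = linalg.transpose ins(%x : tensor<2x3xf32>)
//          outs(%init : tensor<2x3xf32>) permutation = [0, 1]
// folds to %x directly; the memref form is left untouched.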
2088 
2089 /// Fold transpose with transpose.
2090 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2091  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2092 
2093  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2094  PatternRewriter &rewriter) const override {
2095  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2096  if (!defTransposeOp)
2097  return failure();
2098  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2099  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2100  SmallVector<int64_t> foldedPerms;
2101  foldedPerms.reserve(perms.size());
2102  for (int64_t perm : perms)
2103  foldedPerms.push_back(defPerms[perm]);
2104 
2105  rewriter.replaceOpWithNewOp<TransposeOp>(
2106  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2107  foldedPerms);
2108  return success();
2109  }
2110 };
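// For illustration: two back-to-back transposes with permutation [1, 0]
// compose into foldedPerms = [0, 1], i.e. an identity transpose that the
// fold above then erases (shapes omitted for brevity).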
2111 
2112 /// This pattern canonicalizes a transpose by swapping the order of
2113 /// broadcast and transpose:
2114 /// transpose(broadcast(input)) -> broadcast(transpose(input))
2115 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2116  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2117 
2118  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2119  PatternRewriter &rewriter) const override {
2120  Value input = transposeOp.getInput();
2121  BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2122  if (!input.hasOneUse() || !broadcastOp)
2123  return failure();
2124 
2125  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2126  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2127 
2128  // Get new perms and new dimensions.
2129  SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2130  SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2131  SmallVector<int64_t> resultDimensions;
2132  unsigned dimensionSize = dimensions.size();
2133  for (unsigned i = 0; i < dimensionSize; ++i)
2134  resultDimensions.push_back(invertPerm[dimensions[i]]);
2135 
2136  // Create transpose result.
2137  Value broadcastInput = broadcastOp.getInput();
2138  Location loc = transposeOp.getLoc();
2139  MLIRContext *ctx = transposeOp.getContext();
2140  SmallVector<OpFoldResult> dims;
2141  auto broadcastInputTy =
2142  mlir::cast<RankedTensorType>(broadcastInput.getType());
2143  unsigned inputRank = broadcastInputTy.getRank();
2144  for (unsigned i = 0; i < inputRank; ++i) {
2145  if (broadcastInputTy.isDynamicDim(i)) {
2146  dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2147  ->getResult(0));
2148  } else {
2149  dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2150  broadcastInputTy.getDimSize(i)));
2151  }
2152  }
2153  SmallVector<OpFoldResult> transposeResultShapes =
2154  applyPermutation(dims, resultPerms);
2155  Value transposeInit = rewriter.create<tensor::EmptyOp>(
2156  transposeOp.getLoc(), transposeResultShapes,
2157  broadcastInputTy.getElementType());
2158 
2159  // Create broadcast(transpose(input)).
2160  Value transposeResult =
2161  rewriter
2162  .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2163  resultPerms)
2164  ->getResult(0);
2165  rewriter.replaceOpWithNewOp<BroadcastOp>(
2166  transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2167  return success();
2168  }
2169 };
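// For illustration (shapes are hypothetical):
//   %b = linalg.broadcast ins(%x : tensor<4xf32>)
//          outs(%e0 : tensor<2x4xf32>) dimensions = [0]
//   %t = linalg.transpose ins(%b : tensor<2x4xf32>)
//          outs(%e1 : tensor<4x2xf32>) permutation = [1, 0]
// is rewritten into a rank-1 (identity) transpose of %x followed by a
// broadcast with dimensions = [1].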
2170 
2171 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2172  MLIRContext *context) {
2173  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2174 }
2175 
2176 //===----------------------------------------------------------------------===//
2177 // BroadcastOp
2178 //===----------------------------------------------------------------------===//
2179 
2180 void BroadcastOp::build(::mlir::OpBuilder &builder,
2181  ::mlir::OperationState &result, Value input, Value init,
2182  DenseI64ArrayAttr dimensions,
2183  ArrayRef<NamedAttribute> attributes) {
2184  result.addOperands(input);
2185  result.addOperands(init);
2186  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2187  result.addAttributes(attributes);
2188 
2189  // Add output types for `RankedTensorType` output arguments.
2190  Type initType = init.getType();
2191  if (llvm::isa<RankedTensorType>(initType))
2192  result.addTypes(initType);
2193 
2194  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2195  init);
2196 }
2197 
2198 void BroadcastOp::build(::mlir::OpBuilder &builder,
2199  ::mlir::OperationState &result, Value input, Value init,
2200  ArrayRef<int64_t> dimensions,
2201  ArrayRef<NamedAttribute> attributes) {
2202  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2203  attributes);
2204 }
2205 
2206 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2207  if (failed(parseDstStyleOp(
2208  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2209  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2210  })))
2211  return failure();
2212 
2213  OpBuilder builder(parser.getContext());
2214  buildIdentityRegion(builder, result.location, *result.addRegion(),
2215  /*inputs=*/result.operands,
2216  /*outputs=*/{});
2217  return success();
2218 }
2219 
2220 void BroadcastOp::getAsmResultNames(
2221  function_ref<void(Value, StringRef)> setNameFn) {
2222  if (!getResults().empty())
2223  setNameFn(getResults().front(), "broadcasted");
2224 }
2225 
2226 void BroadcastOp::print(OpAsmPrinter &p) {
2227  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2228  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2229  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2230 }
2231 
2232 LogicalResult BroadcastOp::verify() {
2233  ArrayRef<int64_t> dimensionsRef = getDimensions();
2234 
2235  auto inputType = getInput().getType();
2236  auto initType = getInit().getType();
2237 
2238  int64_t inputRank = inputType.getRank();
2239  int64_t initRank = initType.getRank();
2240 
2241  auto inputShape = inputType.getShape();
2242  auto initShape = initType.getShape();
2243 
2244  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2245  return emitOpError() << "input rank plus added dimensions does not "
2246  "match init rank. input rank: "
2247  << inputRank
2248  << ", dimensions size: " << dimensionsRef.size()
2249  << ", init rank: " << initRank;
2250 
2251  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2252  if (dim < 0 || dim >= initRank)
2253  return emitOpError() << "dimension " << idx
2254  << " is out of range. expected range: [0, "
2255  << initRank - 1 << "], got: " << dim;
2256  }
2257 
2258  // Mapping from input dims to init dims.
2259  SmallVector<int64_t> dimMap;
2260  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2261  if (!llvm::is_contained(dimensionsRef, dim))
2262  dimMap.push_back(dim);
2263  }
2264 
2265  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2266  // This dimension is mapped from the input; the init and input dims
2267  // should match.
2268  if (inputShape[inputDimIdx] != initShape[initDimIdx])
2269  return emitOpError() << "input dim " << inputDimIdx
2270  << " should match init dim " << initDimIdx
2271  << ". input: " << inputShape[inputDimIdx]
2272  << ", init: " << initShape[initDimIdx];
2273  }
2274 
2275  return success();
2276 }
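// For illustration: broadcasting a hypothetical tensor<4x5xf32> into
// tensor<4x8x5xf32> uses dimensions = [1]; init dims {0, 2} then map back
// to input dims {0, 1} and must agree in size.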
2277 
2278 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2279  int64_t rank = getInit().getType().getRank();
2280  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2281 }
2282 
2283 ArrayAttr BroadcastOp::getIndexingMaps() {
2284  Builder builder(getContext());
2285  int64_t rank = getInit().getType().getRank();
2286  return builder.getAffineMapArrayAttr(
2287  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2288  builder.getMultiDimIdentityMap(rank)});
2289 }
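// E.g., with init rank 3 and dimensions = [1] this yields
//   input map: (d0, d1, d2) -> (d0, d2)
//   init map:  (d0, d1, d2) -> (d0, d1, d2)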
2290 
2291 void BroadcastOp::getEffects(
2292  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2293  &effects) {
2294  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2295 }
2296 
2297 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2298  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2299 }
2300 
2301 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2302  MLIRContext *context) {
2303  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2304 }
2305 
2306 //===----------------------------------------------------------------------===//
2307 // YieldOp
2308 //===----------------------------------------------------------------------===//
2309 
2310 void linalg::YieldOp::print(OpAsmPrinter &p) {
2311  if (getNumOperands() > 0)
2312  p << ' ' << getOperands();
2313  p.printOptionalAttrDict((*this)->getAttrs());
2314  if (getNumOperands() > 0)
2315  p << " : " << getOperandTypes();
2316 }
2317 
2318 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2319  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2320  SmallVector<Type, 2> types;
2321  SMLoc loc = parser.getCurrentLocation();
2322  return failure(parser.parseOperandList(opInfo) ||
2323  parser.parseOptionalAttrDict(result.attributes) ||
2324  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2325  parser.resolveOperands(opInfo, types, loc, result.operands));
2326 }
2327 
2328 // Check that the number and types of the yield operands match the element
2329 // types of the LinalgOp interface's shaped operands.
2330 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2331  if (op.getNumOperands() != linalgOp.getNumDpsInits())
2332  return op.emitOpError("expected number of yield values (")
2333  << op.getNumOperands()
2334  << ") to match the number of inits / outs operands of the enclosing "
2335  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2336 
2337  for (OpOperand &opOperand : op->getOpOperands()) {
2338  OpOperand *outputOperand =
2339  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2340  Type elementType = outputOperand->get().getType();
2341  if (isa<MemRefType, RankedTensorType>(elementType))
2342  elementType = getElementTypeOrSelf(outputOperand->get().getType());
2343  if (opOperand.get().getType() != elementType)
2344  return op.emitOpError("type of yield operand ")
2345  << (opOperand.getOperandNumber() + 1) << " ("
2346  << opOperand.get().getType() << ") doesn't match "
2347  << "the element type of the enclosing linalg.generic op ("
2348  << elementType << ")";
2349  }
2350  return success();
2351 }
2352 
2353 LogicalResult linalg::YieldOp::verify() {
2354  auto *parentOp = (*this)->getParentOp();
2355  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2356  return emitOpError("expected single non-empty parent region");
2357 
2358  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2359  return verifyYield(*this, linalgOp);
2360 
2361  return emitOpError("expected parent op with LinalgOp interface");
2362 }
2363 
2364 //===----------------------------------------------------------------------===//
2365 // IndexOp
2366 //===----------------------------------------------------------------------===//
2367 
2368 LogicalResult IndexOp::verify() {
2369  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2370  if (!linalgOp)
2371  return emitOpError("expected parent op with LinalgOp interface");
2372  if (linalgOp.getNumLoops() <= getDim())
2373  return emitOpError("expected dim (")
2374  << getDim() << ") to be lower than the number of loops ("
2375  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2376  return success();
2377 }
2378 
2379 OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2380  auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2381  // Bail out if `linalg.index` does not have a proper parent yet at this
2382  // point, e.g., when calling `createOrFold` during IR construction in
2383  // `GenericOp::build`.
2384  if (!linalgOp)
2385  return OpFoldResult{};
2386 
2387  // The index of a unit dimension is always 0.
2388  SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2389  uint64_t dim = getDim();
2390  assert(dim < loopBounds.size() && "Dim is out of bounds");
2391  if (loopBounds[dim] == 1)
2392  return IntegerAttr::get(IndexType::get(getContext()), 0);
2393 
2394  return OpFoldResult{};
2395 }
2396 
2397 /////// Operations corresponding to library calls defined with Tablegen ////////
2398 
2399 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2400 
2401 #define GET_OP_CLASSES
2402 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2403 
2404 #define GET_OP_CLASSES
2405 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2406 #define GET_OP_CLASSES
2407 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2408 
2409 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2410  unsigned rank,
2411  MLIRContext *context) {
2412  if (maybeMap)
2413  return *maybeMap;
2414  if (rank == 0)
2415  return AffineMap::get(context);
2416  return AffineMap::getMultiDimIdentityMap(rank, context);
2417 }
2418 
2419 SmallVector<AffineExpr, 4>
2420 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2421  MLIRContext *context) {
2422  SmallVector<AffineExpr, 4> res;
2423  res.reserve(num);
2424  for (unsigned i = 0; i < num; ++i)
2425  res.push_back(getAffineDimExpr(startIdx++, context));
2426  return res;
2427 }
2428 
2429 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2430  ArrayRef<AffineExpr> b) {
2431  auto rangeA = llvm::make_range(a.begin(), a.end());
2432  auto rangeB = llvm::make_range(b.begin(), b.end());
2433  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2434  return llvm::to_vector<4>(concatRanges);
2435 }
2436 
2437 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2438  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2439  ss << "view";
2440  for (auto size : memref.getShape())
2441  if (size < 0)
2442  ss << "sx";
2443  else
2444  ss << size << "x";
2445  if (failed(appendMangledType(ss, memref.getElementType())))
2446  return failure();
2447  if (auto as = memref.getMemorySpace()) {
2448  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2449  ss << "as" << attr.getInt();
2450  else
2451  return failure();
2452  }
2453  return success();
2454  }
2455  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2456  ss << "vector";
2457  llvm::interleave(
2458  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2459  if (failed(appendMangledType(ss, vec.getElementType())))
2460  return failure();
2461  return success();
2462  }
2463  if (t.isSignlessIntOrIndexOrFloat()) {
2464  ss << t;
2465  return success();
2466  }
2467  return failure();
2468 }
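// For illustration, a hypothetical memref<4x?xf32> mangles to
// "view4xsxf32" and vector<4x4xf32> mangles to "vector4x4xf32".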
2469 
2470 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2471  assert(isa<LinalgOp>(op));
2472  std::string name(op->getName().getStringRef().str());
2473  std::string fun = "";
2474  for (NamedAttribute kv : op->getAttrs()) {
2475  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2476  fun = stringifyEnum(ufa.getValue()).str() + "_";
2477  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2478  fun = stringifyEnum(bfa.getValue()).str() + "_";
2479  }
2480  }
2481  name.reserve(128);
2482  llvm::replace(name, '.', '_');
2483  llvm::raw_string_ostream ss(name);
2484  ss << "_" << fun;
2485  for (Type t : op->getOperandTypes()) {
2486  if (failed(appendMangledType(ss, t)))
2487  return std::string();
2488  ss << "_";
2489  }
2490  name.pop_back();
2491  return name;
2492 }
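// For illustration, a linalg op named "linalg.copy" on two hypothetical
// memref<?xf32> operands would produce a name like
// "linalg_copy_viewsxf32_viewsxf32" (the trailing '_' is popped).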
2493 
2494 //===----------------------------------------------------------------------===//
2495 // Canonicalizers and Folders.
2496 //===----------------------------------------------------------------------===//
2497 
2498 namespace {
2499 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2500  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2501 
2502  LogicalResult matchAndRewrite(LinalgOp op,
2503  PatternRewriter &rewriter) const override {
2504  for (OpOperand &opOperand : op->getOpOperands()) {
2505  // Linalg "inputs" may be either tensor or memref type.
2506  // tensor<0xelt_type> is a convention that may not always mean
2507  // "0 iterations". Only erase in cases where we see memref<...x0x...>.
2508  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2509  if (!mt)
2510  continue;
2511  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2512  rewriter.eraseOp(op);
2513  return success();
2514  }
2515  }
2516  return failure();
2517  }
2518 };
2519 
2520 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2521 /// result that is more static than the linalg op.
2522 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2523  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2524 
2525  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2526  PatternRewriter &rewriter) const override {
2527  if (!tensor::canFoldIntoProducerOp(castOp))
2528  return failure();
2529 
2530  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2531  if (!linalgOp)
2532  return failure();
2533 
2534  // The cast can be in a conditionally reachable region, in which case
2535  // folding would generate invalid code. Only conservatively fold ops in
2536  // the same block for now.
2537  if (castOp->getBlock() != linalgOp->getBlock())
2538  return failure();
2539 
2540  OpBuilder::InsertionGuard guard(rewriter);
2541  rewriter.setInsertionPoint(linalgOp);
2542 
2543  Location loc = linalgOp.getLoc();
2544  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2545  unsigned resultNumber = resultValue.getResultNumber();
2546  auto resultType =
2547  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2548  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2549  // going from a more dynamic shape to a less dynamic shape. If the producer
2550  // for this cast, i.e. producer of the out operand, is also an operation
2551  // that folds with tensor.cast consumer (like this pattern), the cast will
2552  // continue to propagate as far up the stack as it can go.
2553  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2554  Value newOperand =
2555  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2556  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2557  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2558  linalgOp.getDpsInits().end());
2559  outputOperands[resultNumber] = newOperand;
2560  newOperands.append(outputOperands.begin(), outputOperands.end());
2561 
2562  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2563  linalgOp->result_type_end());
2564  resultTypes[resultNumber] = resultType;
2565  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2566 
2567  // Create a tensor.cast operation back to the original type.
2568  Value castBack = rewriter.create<tensor::CastOp>(
2569  loc, resultValue.getType(), newOp->getResult(resultNumber));
2570 
2571  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2572  results[resultNumber] = castBack;
2573  rewriter.replaceOp(linalgOp, results);
2574  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2575  return success();
2576  }
2577 };
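// For illustration (hypothetical IR): given
//   %0 = linalg.generic ... outs(%out : tensor<?x?xf32>) -> tensor<?x?xf32>
//   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
// the pattern casts %out to tensor<4x8xf32>, clones the generic with the
// more static result type, and casts the new result back to tensor<?x?xf32>
// for the remaining uses of %0.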
2578 
2579 /// For each of the operands in `operands`, this function maps the static
2580 /// sizes of dimensions to their affine dim expressions.
2581 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2582  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2583  for (OpOperand &opOperand : operands) {
2584  if (linalgOp.isScalar(&opOperand))
2585  continue;
2586  Value src = opOperand.get();
2587  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2588  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2589 
2590  // Get the `sourceShape` of the `sourceType`. If the operand is a result
2591  // of a `tensor.cast` operation and the source of the cast has a static
2592  // shape, then use that shape as the `sourceShape`.
2593  auto *parentOp = src.getDefiningOp();
2594  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2595  if (parentOp) {
2596  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2597  Value castSource = castOp.getSource();
2598  auto castSourceType =
2599  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2600  if (castSourceType && castSourceType.hasStaticShape())
2601  sourceShape = castSourceType.getShape();
2602  }
2603  }
2604 
2605  // If the source shape's dimension has a static shape, map the affine dim
2606  // expression to the known static size.
2607  for (unsigned i = 0; i < sourceShape.size(); i++) {
2608  if (sourceType.isDynamicDim(i))
2609  continue;
2610  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2611  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2612  }
2613  }
2614 }
2615 
2616 /// Creates a new operand w.r.t. `opOperand` of `linalgOp` with the static
2617 /// sizes mapped in `affineExprToSize`. New operands are appended to
2618 /// `newOperands` and their result types are stored in `resultTypes`. If
2619 /// `opOperand` requires no change, then `changeNeeded` stays false and the
2620 /// same operand is added to the `newOperands` list.
2621 static void createNewOperandWithStaticSizes(
2622  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2623  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2624  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2625  bool &changeNeeded) {
2626  Value src = opOperand->get();
2627  newOperands.push_back(src);
2628  if (linalgOp.isScalar(opOperand))
2629  return;
2630  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2631  Type resultType = sourceType;
2632  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2633  resultTypes.push_back(resultType);
2634  return;
2635  }
2636  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2637  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2638  SmallVector<int64_t> newShape;
2639  // If the operand is updated with a new shape, `newOperandNeeded` will
2640  // be set to true.
2641  bool newOperandNeeded = false;
2642  for (unsigned i = 0; i < sourceShape.size(); i++) {
2643  int64_t dimShape = sourceShape[i];
2644  AffineExpr dimExpr = sourceMap.getResult(i);
2645  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2646  newShape.push_back(dimShape);
2647  continue;
2648  }
2649  // The dimension has a dynamic shape and the corresponding affine dim
2650  // expression is present in the map, so assign that size to the
2651  // dimension.
2652  newShape.push_back(affineExprToSize[dimExpr]);
2653  newOperandNeeded = true;
2654  }
2655  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2656  sourceType.getEncoding());
2657  if (newOperandNeeded) {
2658  changeNeeded = true;
2659  // Get the new operand value given its size and element type by
2660  // casting it.
2661  Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2662  unsigned index = opOperand->getOperandNumber();
2663  newOperands[index] = newOperand;
2664  }
2665  if (linalgOp.isDpsInit(opOperand))
2666  resultTypes.push_back(resultType);
2667 }
2668 
2669 /// Static shapes for the operands can be inferred if any one of the operands
2670 /// have a static shape. This can be done by referring to the affine dim
2671 /// expressions for the operand.
2672 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2673  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2674 
2675  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2676  PatternRewriter &rewriter) const override {
2677  if (!linalgOp.hasPureTensorSemantics())
2678  return failure();
2679 
2680  // Maps must be projected permutations.
2681  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2682  return !map.isProjectedPermutation();
2683  }))
2684  return failure();
2685 
2686  // Maps affine dim expressions to the static size of that dimension.
2687  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2688  Location loc = linalgOp.getLoc();
2689 
2690  // For each affine dim expression, check if the size is known. If it is
2691  // known, add it to the map.
2692  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2693 
2694  SmallVector<Value> newOperands;
2695  SmallVector<Type> resultTypes;
2696 
2697  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2698  // change in their types.
2699  bool changeNeeded = false;
2700  newOperands.reserve(linalgOp->getNumOperands());
2701  resultTypes.reserve(linalgOp.getNumDpsInits());
2702 
2703  // Iterate over all the operands and update the static sizes.
2704  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2705  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2706  affineExprToSize, linalgOp, newOperands,
2707  resultTypes, changeNeeded);
2708  }
2709 
2710  // If the generic op already has all the required static information, no
2711  // canonicalization is needed.
2712  if (!changeNeeded)
2713  return failure();
2714 
2715  // Clone op.
2716  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2717  SmallVector<Value> replacements;
2718  replacements.reserve(newOp->getNumResults());
2719  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2720  Value newResult = std::get<1>(it);
2721  Value oldResult = std::get<0>(it);
2722  Type newType = newResult.getType();
2723  Type oldType = oldResult.getType();
2724  replacements.push_back(
2725  (newType != oldType)
2726  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2727  : newResult);
2728  }
2729  rewriter.replaceOp(linalgOp, replacements);
2730  return success();
2731  }
2732 };
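// For illustration (hypothetical IR): in an elementwise generic whose inputs
// are tensor<4x8xf32> and tensor<?x?xf32> under identity maps, the dynamic
// operand is cast to tensor<4x8xf32> before the op is cloned, because both
// operands share the same affine dim expressions d0 and d1.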
2733 
2734 } // namespace
2735 
2736 // All named ops canonicalizers and folders are auto-generated in the
2737 // .cpp.inc.
2738 
2739 //===----------------------------------------------------------------------===//
2740 // SoftmaxOp
2741 //===----------------------------------------------------------------------===//
2742 
2743 LogicalResult SoftmaxOp::verify() {
2744  ShapedType inputType = getInputOperandType();
2745  ShapedType outputType = getOutputOperandType();
2746 
2747  ArrayRef<int64_t> inputShape = inputType.getShape();
2748  ArrayRef<int64_t> outputShape = outputType.getShape();
2749  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2750  return emitOpError("incompatible output shape");
2751 
2752  int64_t inputRank = getInputOperandRank();
2753  int64_t dimension = getDimension();
2754  if ((dimension < 0) || (dimension >= inputRank))
2755  return emitOpError("incorrect dimension specified");
2756 
2757  return success();
2758 }
2759 
2760 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2761  int64_t operandRank = getInputOperandRank();
2762  SmallVector<Range> loopBounds(operandRank);
2763  Location loc = getLoc();
2764  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2765  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2766  Value source = getInput();
2767  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2768  loopBounds[dim].offset = zero;
2769  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2770  loopBounds[dim].stride = one;
2771  }
2772  return loopBounds;
2773 }
2774 
2775 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2776  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2777  utils::IteratorType::parallel);
2778  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2779  return iteratorTypes;
2780 }
2781 
2782 FailureOr<TilingResult>
2783 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2784  ArrayRef<OpFoldResult> offsets,
2785  ArrayRef<OpFoldResult> sizes) {
2786  int64_t rank = getInputOperandRank();
2787  auto oneAttr = builder.getI64IntegerAttr(1);
2788  SmallVector<OpFoldResult> strides(rank, oneAttr);
2789  SmallVector<Value> tiledOperands;
2790  Operation *inputSlice =
2791  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2792  if (!inputSlice) {
2793  return emitOpError("failed to compute input slice");
2794  }
2795  tiledOperands.emplace_back(inputSlice->getResult(0));
2796  Operation *outputSlice =
2797  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2798  if (!outputSlice) {
2799  return emitOpError("failed to compute output slice");
2800  }
2801  tiledOperands.emplace_back(outputSlice->getResult(0));
2802 
2803  SmallVector<Type, 4> resultTypes;
2804  if (hasPureTensorSemantics())
2805  resultTypes.push_back(tiledOperands[1].getType());
2806  Operation *tiledOp =
2807  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2808 
2809  return TilingResult{
2810  {tiledOp},
2811  SmallVector<Value>(tiledOp->getResults()),
2812  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2813 }
2814 
2815 LogicalResult SoftmaxOp::getResultTilePosition(
2816  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2817  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2818  SmallVector<OpFoldResult> &resultSizes) {
2819  if (resultNumber == 0) {
2820  resultOffsets.assign(offsets.begin(), offsets.end());
2821  resultSizes.assign(sizes.begin(), sizes.end());
2822  return success();
2823  }
2824  return failure();
2825 }
2826 
2827 // cast(dynamic) -> static.
2828 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2829  return memref::foldMemRefCast(*this);
2830 }
2831 
2832 LogicalResult
2833 SoftmaxOp::reifyResultShapes(OpBuilder &b,
2834  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2835  SmallVector<OpFoldResult> shapes;
2836  Location loc = getOperation()->getLoc();
2837  IRRewriter rewriter(b);
2838  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2839  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2840  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2841  if (!outputShapedType.isDynamicDim(dim)) {
2842  // Static dim: Return IntegerAttr.
2843  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2844  } else {
2845  // Dynamic dim: Return Value.
2846  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2847  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2848  }
2849  }
2850  reifiedReturnShapes.emplace_back(std::move(shapes));
2851  return success();
2852 }
2853 
2854 void SoftmaxOp::getEffects(
2855  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2856  &effects) {
2857  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2858  if (!llvm::isa<MemRefType>(operand.getType()))
2859  continue;
2860  effects.emplace_back(MemoryEffects::Read::get(),
2861  &getOperation()->getOpOperand(index), /*stage=*/0,
2862  /*effectOnFullRegion=*/true,
2863  SideEffects::DefaultResource::get());
2864  }
2865 
2866  for (OpOperand &operand : getDpsInitsMutable()) {
2867  if (!llvm::isa<MemRefType>(operand.get().getType()))
2868  continue;
2869  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2870  /*effectOnFullRegion=*/true,
2871  SideEffects::DefaultResource::get());
2872  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2873  /*effectOnFullRegion=*/true,
2874  SideEffects::DefaultResource::get());
2875  }
2876 }
2877 
2878 // Helper functions for softmax decomposition.
2879 // @{
2880 
2881 // Helper function to produce the iterator types (reduction or parallel) and
2882 // affine maps for the iterators used in the decomposition of softmax.
2883 // This method creates:
2884 // If allParallel == true:
2885 // - iterator type: {parallel, ..., parallel}
2886 // - affine maps:
2887 // -- identity with inputRank dimensions.
2888 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2889 // where N == inputRank.
2890 //
2891 // If allParallel == false:
2892 // - iterator types: parallel at dim(i) for every i != \p dim, and
2893 // reduction at dim(dim).
2894 // - affine map:
2895 // -- identity with inputRank dimensions.
2896 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2897 // where N == inputRank.
2898 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2899 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2900  int64_t dim, bool allParallel = false) {
2901  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2902  utils::IteratorType::parallel);
2903  if (!allParallel)
2904  iteratorTypes[dim] = utils::IteratorType::reduction;
2905  MLIRContext *ctxt = builder.getContext();
2906  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2907  SmallVector<AffineExpr, 2> affineExprs;
2908  for (int i = 0; i < inputRank; i++) {
2909  if (i != dim)
2910  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2911  }
2912  auto reductionMap =
2913  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2914  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2915  return std::make_tuple(iteratorTypes, indexingMaps);
2916 }
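// E.g., for inputRank = 3 and dim = 1 this returns
//   identityMap:  (d0, d1, d2) -> (d0, d1, d2)
//   reductionMap: (d0, d1, d2) -> (d0, d2)
// with iterator types {parallel, reduction, parallel} when !allParallel.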
2917 
2918 // Helper function to produce a linalg.generic that computes a reduction on
2919 // dimension \p dim with the operation type \p T.
2920 template <typename T>
2921 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2922  int64_t dim) {
2923  auto inputType = cast<ShapedType>(input.getType());
2924  ArrayRef<int64_t> inputShape = inputType.getShape();
2925  int64_t inputRank = inputShape.size();
2926  auto [iteratorTypes, indexingMaps] =
2927  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2928  assert(indexingMaps.size() == 2 &&
2929  "We should have two maps: 1 for the input, 1 for the output");
2930  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2931 
2932  auto genericOp = builder.create<linalg::GenericOp>(
2933  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2934  [&](OpBuilder &b, Location loc, ValueRange args) {
2935  Value result = b.create<T>(loc, args[0], args[1]);
2936  b.create<linalg::YieldOp>(loc, result);
2937  });
2938  return genericOp.getResult(0);
2939 }
2940 
2941 /// Produce a linalg generic that computes the second step of the softmax
2942 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2943 /// on dimension \p dim.
2944 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2945  Value max, Value output, int64_t dim) {
2946  auto inputType = cast<ShapedType>(input.getType());
2947  ArrayRef<int64_t> inputShape = inputType.getShape();
2948  int64_t inputRank = inputShape.size();
2949  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2950  builder, inputRank, dim, /*allParallel=*/true);
2951  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2952  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2953  // Add the affine map for the output argument.
2954  indexingMaps.push_back(indexingMaps[0]);
2955  auto genericOp = builder.create<linalg::GenericOp>(
2956  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2957  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2958  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2959  Value result = b.create<math::ExpOp>(loc, diff);
2960  b.create<linalg::YieldOp>(loc, result);
2961  });
2962  return genericOp.getResult(0);
2963 }
2964 
2965 /// Produce a linalg generic that computes the final step of the softmax
2966 /// decomposition.
2967 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2968 /// yield n / d
2969 /// }
2970 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2971  Value denominator, Value output, int64_t dim) {
2972  auto inputType = cast<ShapedType>(numerator.getType());
2973  ArrayRef<int64_t> inputShape = inputType.getShape();
2974  int64_t inputRank = inputShape.size();
2975  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2976  builder, inputRank, dim, /*allParallel=*/true);
2977  assert(indexingMaps.size() == 2 &&
2978  "We should have one map for each input (2)");
2979  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2980  // Add the affine map for the output tensor.
2981  indexingMaps.push_back(indexingMaps[0]);
2982  auto genericOp = builder.create<linalg::GenericOp>(
2983  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2984  indexingMaps, iteratorTypes,
2985  [&](OpBuilder &b, Location loc, ValueRange args) {
2986  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2987  b.create<linalg::YieldOp>(loc, result);
2988  });
2989  return genericOp.getResult(0);
2990 }
2991 // @} End helper functions for softmax decomposition.
2992 
2993 /// Given an N-dimensional tensor x, this method converts
2994 /// softmax(x) to the following sequence of operations:
2995 ///
2996 /// 1. Compute the max of x along dimension d. This results
2997 /// in a N-1 dimensional tensor m.
2998 /// m = max(x, dim = d)
2999 ///
3000 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
3001 /// an N-dimensional tensor z.
3002 /// z = exp(x - m)
3003 ///
3004 /// 3. Compute the sum of z along dimension d. This results in
3005 /// an N-1 dimensional tensor l.
3006 /// l = sum(z, dim = d)
3007 ///
3008 /// 4. Divide z by l. This gives the N-dimensional softmax.
3009 /// softmax = z / l
3010 ///
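/// As an illustrative sketch (not verbatim output of this decomposition),
/// for an input %x reduced along dimension d the emitted IR has this shape:
///
///   %m0 = linalg.fill ins(%neg_inf) outs(%empty)        // step 1
///   %m  = <linalg.generic: maxnumf reduction over d>(%x, %m0)
///   %z  = <linalg.generic: exp(%x - broadcast(%m))>     // step 2
///   %l0 = linalg.fill ins(%zero) outs(%empty)           // step 3
///   %l  = <linalg.generic: addf reduction over d>(%z, %l0)
///   %r  = <linalg.generic: %z / broadcast(%l)>          // step 4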
3011 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
3012  OpBuilder::InsertionGuard guard(b);
3013  b.setInsertionPoint(*this);
3014  Location loc = getLoc();
3015  Value input = getInput();
3016  ShapedType inputType = getInputOperandType();
3017  Type elementType = inputType.getElementType();
3018  int64_t reductionDim = getDimension();
3019  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
3020  Value output = getOutput();
3021  dims.erase(dims.begin() + reductionDim);
3022  // Step 1: Compute max along dim.
3023  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
3024  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
3025  elementType, b, loc,
3026  /*useOnlyFiniteValue=*/true);
3027  Value neutralForMaxFInit =
3028  b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
3029  .result();
3030  Value max =
3031  reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
3032 
3033  // Step 2: Subtract max from input and exponentiate.
3034  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
3035 
3036  // Step 3: Compute sum along dim.
3037  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
3038  b, loc, /*useOnlyFiniteValue=*/true);
3039  Value zeroInit =
3040  b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
3041  Value denominator =
3042  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
3043 
3044  // Step 4: Compute softmax.
3045  Value result =
3046  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
3047  return SmallVector<Value>{result};
3048 }
3049 
3050 //===----------------------------------------------------------------------===//
3051 // WinogradFilterTransformOp
3052 //===----------------------------------------------------------------------===//
3053 
3054 LogicalResult WinogradFilterTransformOp::verify() {
3055  auto filterType = cast<ShapedType>(getFilter().getType());
3056  ArrayRef<int64_t> filterShape = filterType.getShape();
3057  int64_t filterH = filterShape[getFilterHDim()];
3058  int64_t filterW = filterShape[getFilterWDim()];
3059  WinogradConv2DFmr fmr = getFmr();
3060  int64_t m, r;
3061  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3062 
3063  if (filterH != r && filterH != 1)
3064  return emitOpError("expect filter height either equals to r or 1");
3065  if (filterW != r && filterW != 1)
3066  return emitOpError("expect filter width either equals to r or 1");
3067  if (filterH == 1 && filterW == 1)
3068  return emitOpError("expect either filter height or width equals to r");
3069 
3070  SmallVector<int64_t> expectedOutputShape;
3071  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3072  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3073  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3074  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3075 
3076  auto outputType = cast<ShapedType>(getOutput().getType());
3077  ArrayRef<int64_t> outputShape = outputType.getShape();
3078  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3079  return emitOpError("the output shape is not expected");
3080  }
3081  return success();
3082 }
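// Illustration: with fmr = F_2_3 (m = 2, r = 3), a filter with
// KH = KW = r = 3 passes this verifier and the expected output shape is
// (4, 4, C, F), since alpha = m + r - 1 = 4; with KH = 1 (a 1-D filter)
// the first output dim collapses to 1.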
3083 
3084 SmallVector<Range>
3085 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3086  Location loc = getLoc();
3087  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3088  IntegerAttr oneAttr = builder.getIndexAttr(1);
3089  Value filter = getFilter();
3090  int64_t filterRank = getFilterOperandRank();
3091  SmallVector<Range> loopBounds(filterRank);
3092  for (unsigned dim = 0; dim < filterRank; ++dim) {
3093  loopBounds[dim].offset = zeroAttr;
3094  loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3095  loopBounds[dim].stride = oneAttr;
3096  }
3097  return loopBounds;
3098 }
3099 
3100 SmallVector<utils::IteratorType>
3101 WinogradFilterTransformOp::getLoopIteratorTypes() {
3102  int64_t filterRank = getFilterOperandRank();
3103  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3104  utils::IteratorType::parallel);
3105  return iteratorTypes;
3106 }
3107 
3108 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3109  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3110  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3111  SmallVector<OpFoldResult> &resultSizes) {
3112  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3113  ShapedType filterType = getFilterOperandType();
3114  ArrayRef<int64_t> filterShape = filterType.getShape();
3115  int64_t filterH = filterShape[getFilterHDim()];
3116  int64_t filterW = filterShape[getFilterWDim()];
3117  WinogradConv2DFmr fmr = getFmr();
3118  int64_t m, r;
3119  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3120  int64_t alpha = m + r - 1;
3121  int64_t alphaH = filterH != 1 ? alpha : 1;
3122  int64_t alphaW = filterW != 1 ? alpha : 1;
3123  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3124  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3125 
3126  resultOffsets.append(
3127  {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3128  resultSizes.append(
3129  {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3130 
3131  return success();
3132 }
3133 
3134 /// Implement tiling for winograd_filter_transform.
3135 /// The input of winograd_filter_transform is (F, KH, KW, C).
3136 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
3137 /// Users can specify the tile sizes of F and C.
3138 /// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3139 /// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3140 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3141  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3142  ArrayRef<OpFoldResult> sizes) {
3143  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3144  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3145  ShapedType filterType = getFilterOperandType();
3146  ArrayRef<int64_t> filterShape = filterType.getShape();
3147  int64_t filterH = filterShape[getFilterHDim()];
3148  int64_t filterW = filterShape[getFilterWDim()];
3149  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3150  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3151  SmallVector<Value> tiledOperands;
3152  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3153 
3154  sliceOffsets.append(
3155  {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3156  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3157  sizes[getFilterCDim()]});
3158  int64_t filterRank = getFilterOperandRank();
3159  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3160  Location loc = getLoc();
3161  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
3162  loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3163  tiledOperands.emplace_back(filterSlice);
3164 
3165  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3166  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3167  resultSizes)))
3168  return failure();
3169 
3170  int64_t outputRank = getOutputOperandRank();
3171  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3172  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3173  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3174  tiledOperands.emplace_back(outputSlice);
3175 
3176  SmallVector<Type> resultTypes;
3177  resultTypes.push_back(tiledOperands[1].getType());
3178  Operation *tiledOp =
3179  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3180 
3181  return TilingResult{
3182  {tiledOp},
3183  SmallVector<Value>(tiledOp->getResults()),
3184  llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3185 }
3186 
3187 //===----------------------------------------------------------------------===//
3188 // WinogradInputTransformOp
3189 //===----------------------------------------------------------------------===//
3190 
3191 LogicalResult WinogradInputTransformOp::verify() {
3192  auto inputType = cast<ShapedType>(getInput().getType());
3193  ArrayRef<int64_t> inputShape = inputType.getShape();
3194  int64_t inputH = inputShape[getInputHDim()];
3195  int64_t inputW = inputShape[getInputWDim()];
3196  WinogradConv2DFmr fmr = getFmr();
3197  int64_t m, r;
3198  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3199  int64_t tileSize = m + r - 1;
3200 
3201  auto outputType = cast<ShapedType>(getOutput().getType());
3202  ArrayRef<int64_t> outputShape = outputType.getShape();
3203  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3204  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3205 
3206  SmallVector<int64_t> expectedOutputShape(6, inputH);
3207  if (ShapedType::isDynamic(inputH)) {
3208  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3209  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3210  } else {
3211  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3212  expectedOutputShape[getOutputTileHDim()] =
3213  leftTransform ? (inputH - (r - 1)) / m : inputH;
3214  }
3215  if (ShapedType::isDynamic(inputW)) {
3216  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3217  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3218  } else {
3219  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3220  expectedOutputShape[getOutputTileWDim()] =
3221  rightTransform ? (inputW - (r - 1)) / m : inputW;
3222  }
3223  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3224  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3225 
3226  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3227  return emitOpError("the output shape is not expected");
3228  }
3229  return success();
3230 }
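// Illustration: with fmr = F_2_3 (m = 2, r = 3) and a static 8x8 spatial
// input with both transforms active, tileSize = m + r - 1 = 4 and each
// spatial dim yields (8 - (r - 1)) / m = 3 tiles, so the expected output
// shape is (4, 4, 3, 3, N, C).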
3231 
3232 SmallVector<Range>
3233 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3234  Location loc = getLoc();
3235  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3236  IntegerAttr oneAttr = builder.getIndexAttr(1);
3237  Value output = getOutput();
3238  int64_t outputRank = getOutputOperandRank();
3239  SmallVector<Range> loopBounds(outputRank);
3240  for (unsigned dim = 0; dim < outputRank; ++dim) {
3241  loopBounds[dim].offset = zeroAttr;
3242  // alphaH, alphaW, tileH, tileW, N, C
3243  loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3244  loopBounds[dim].stride = oneAttr;
3245  }
3246  return loopBounds;
3247 }
3248 
3249 SmallVector<utils::IteratorType>
3250 WinogradInputTransformOp::getLoopIteratorTypes() {
3251  int64_t outputRank = getOutputOperandRank();
3252  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3253  utils::IteratorType::parallel);
3254  return iteratorTypes;
3255 }
3256 
3257 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3258  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3259  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3260  SmallVector<OpFoldResult> &resultSizes) {
3261  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3262  ShapedType outputType = getOutputOperandType();
3263  ArrayRef<int64_t> outputShape = outputType.getShape();
3264  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3265  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3266 
3267  WinogradConv2DFmr fmr = getFmr();
3268  int64_t m, r;
3269  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3270  int64_t alpha = m + r - 1;
3271  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3272  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3273 
3274  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3275  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3276 
3277  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3278  offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3279  offsets[getOutputCDim()]});
3280  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3281  sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3282  sizes[getOutputCDim()]});
3283 
3284  return success();
3285 }
3286 
3287 /// Implement tiling for winograd_input_transform.
3288 /// The input of winograd_input_transform is (N, H, W, C).
3289 /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
3290 /// C). Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
3291 /// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
3292 /// the values for the sizes of tileH, tileW, N, C for one tile.
3293 FailureOr<TilingResult>
3294 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3295  ArrayRef<OpFoldResult> offsets,
3296  ArrayRef<OpFoldResult> sizes) {
3297  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3298  WinogradConv2DFmr fmr = getFmr();
3299  int64_t m, r;
3300  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3301 
3302  ShapedType outputType = getOutputOperandType();
3303  ArrayRef<int64_t> outputShape = outputType.getShape();
3304  int64_t alphaH = outputShape[getOutputAlphaHDim()];
3305  int64_t alphaW = outputShape[getOutputAlphaWDim()];
3306 
3307  Location loc = getLoc();
3308  MLIRContext *context = builder.getContext();
3309  auto identityAffineMap =
3310  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3311  auto offsetAffineMap =
3312  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3313  Value mappedOffsetH = affine::makeComposedAffineApply(
3314  builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3315  offsets[getOutputTileHDim()]);
3316  Value mappedOffsetW = affine::makeComposedAffineApply(
3317  builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3318  offsets[getOutputTileWDim()]);
3319  auto sizeAffineMap = AffineMap::get(
3320  1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3321  Value mappedSizeH = affine::makeComposedAffineApply(
3322  builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3323  Value mappedSizeW = affine::makeComposedAffineApply(
3324  builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3325 
3326  SmallVector<Value> tiledOperands;
3327  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3328 
3329  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3330  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3331  sliceOffsets.append(
3332  {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3333  OpFoldResult sizeH =
3334  alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3335  OpFoldResult sizeW =
3336  alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3337  sliceSizes.append(
3338  {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3339  int64_t inputRank = getInputOperandRank();
3340  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3341  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
3342  loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3343  tiledOperands.emplace_back(inputSlice);
3344 
3345  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3346  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3347  resultSizes)))
3348  return failure();
3349 
3350  int64_t outputRank = getOutputOperandRank();
3351  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3352  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3353  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3354  tiledOperands.emplace_back(outputSlice);
3355 
3356  SmallVector<Type> resultTypes;
3357  resultTypes.push_back(tiledOperands[1].getType());
3358  Operation *tiledOp =
3359  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3360 
3361  return TilingResult{
3362  {tiledOp},
3363  SmallVector<Value>(tiledOp->getResults()),
3364  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3365 }
3366 
3367 //===----------------------------------------------------------------------===//
3368 // WinogradOutputTransformOp
3369 //===----------------------------------------------------------------------===//
3370 
3371 LogicalResult WinogradOutputTransformOp::verify() {
3372  auto valueType = cast<ShapedType>(getValue().getType());
3373  ArrayRef<int64_t> valueShape = valueType.getShape();
3374  int64_t valueH = valueShape[getValueAlphaHDim()];
3375  int64_t valueW = valueShape[getValueAlphaWDim()];
3376  int64_t valueTileH = valueShape[getValueTileHDim()];
3377  int64_t valueTileW = valueShape[getValueTileWDim()];
3378  WinogradConv2DFmr fmr = getFmr();
3379  int64_t m, r;
3380  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3381  bool leftTransform = valueH != 1;
3382  bool rightTransform = valueW != 1;
3383 
3384  int64_t outputRank = getOutputOperandRank();
3385  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3386  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3387  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3388  } else {
3389  if (valueH != (leftTransform ? m + r - 1 : 1))
3390  return emitOpError("expect input height equals to input tile size");
3391  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3392  }
3393  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3394  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3395  } else {
3396  if (valueW != (rightTransform ? m + r - 1 : 1))
3397  return emitOpError("expect input width equals to input tile size");
3398  expectedOutputShape[getOutputWDim()] =
3399  (rightTransform ? m : 1) * valueTileW;
3400  }
3401  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3402  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3403 
3404  auto outputType = cast<ShapedType>(getOutput().getType());
3405  ArrayRef<int64_t> outputShape = outputType.getShape();
3406  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3407  return emitOpError("the output shape is not expected");
3408  }
3409  return success();
3410 }
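// Illustration: with fmr = F_2_3 (m = 2, r = 3) and both transforms
// active, the value operand must have alphaH = alphaW = m + r - 1 = 4;
// with tileH = tileW = 3 the reconstructed output is (N, 6, 6, F), since
// each spatial dim is m times the tile count.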
3411 
3412 SmallVector<Range>
3413 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3414  Location loc = getLoc();
3415  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3416  IntegerAttr oneAttr = builder.getIndexAttr(1);
3417  Value value = getValue();
3418  int64_t valueRank = getValueOperandRank();
3419  SmallVector<Range> loopBounds(valueRank);
3420  for (unsigned dim = 0; dim < valueRank; ++dim) {
3421  loopBounds[dim].offset = zeroAttr;
3422  // alphaH, alphaW, tileH, tileW, N, F
3423  loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3424  loopBounds[dim].stride = oneAttr;
3425  }
3426  return loopBounds;
3427 }
3428 
3429 SmallVector<utils::IteratorType>
3430 WinogradOutputTransformOp::getLoopIteratorTypes() {
3431  int64_t valueRank = getValueOperandRank();
3432  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3433  utils::IteratorType::parallel);
3434  return iteratorTypes;
3435 }
3436 
3437 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3438  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3439  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3440  SmallVector<OpFoldResult> &resultSizes) {
3441  WinogradConv2DFmr fmr = getFmr();
3442  int64_t m, r;
3443  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3444 
3445  Location loc = getLoc();
3446  MLIRContext *context = builder.getContext();
3447  auto identityAffineMap =
3448  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3449  auto affineMap =
3450  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3451 
3452  ShapedType valueType = getValueOperandType();
3453  ArrayRef<int64_t> valueShape = valueType.getShape();
3454  int64_t valueH = valueShape[0];
3455  int64_t valueW = valueShape[1];
3456  Value mappedOffsetH = affine::makeComposedAffineApply(
3457  builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3458  offsets[getValueTileHDim()]);
3459  Value mappedOffsetW = affine::makeComposedAffineApply(
3460  builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3461  offsets[getValueTileWDim()]);
3462  Value mappedSizeH = affine::makeComposedAffineApply(
3463  builder, loc, affineMap, sizes[getValueTileHDim()]);
3464  Value mappedSizeW = affine::makeComposedAffineApply(
3465  builder, loc, affineMap, sizes[getValueTileWDim()]);
3466 
3467  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3468  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3469  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3470  OpFoldResult sizeH =
3471  valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3472  OpFoldResult sizeW =
3473  valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3474 
3475  resultOffsets.append(
3476  {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3477  resultSizes.append(
3478  {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3479  return success();
3480 }
3481 
3482 /// Implement tiling for winograd_output_transform.
3483 /// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
3484 /// F). The output of winograd_output_transform is (N, H, W, F). Users can
3485 /// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3486 /// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3487 /// for the sizes of tileH, tileW, N, F for one tile.
3488 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3489  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3490  ArrayRef<OpFoldResult> sizes) {
3491  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3492  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3493  Location loc = getLoc();
3494  SmallVector<Value> tiledOperands;
3495  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3496 
3497  ShapedType valueType = getValueOperandType();
3498  ArrayRef<int64_t> valueShape = valueType.getShape();
3499  int64_t alphaH = valueShape[getValueAlphaHDim()];
3500  int64_t alphaW = valueShape[getValueAlphaWDim()];
3501  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3502  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3503 
3504  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3505  offsets[getValueTileWDim()], offsets[getValueNDim()],
3506  offsets[getValueFDim()]});
3507  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3508  sizes[getValueTileWDim()], sizes[getValueNDim()],
3509  sizes[getValueFDim()]});
3510  int64_t valueRank = getValueOperandRank();
3511  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3512  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
3513  loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3514  tiledOperands.emplace_back(valueSlice);
3515 
3516  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3517  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3518  resultSizes)))
3519  return failure();
3520 
3521  int64_t outputRank = getOutputOperandRank();
3522  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3523  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3524  loc, getOutput(), resultOffsets, resultSizes, strides);
3525  tiledOperands.emplace_back(outputSlice);
3526 
3527  SmallVector<Type> resultTypes;
3528  resultTypes.push_back(tiledOperands[1].getType());
3529  Operation *tiledOp =
3530  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3531 
3532  return TilingResult{
3533  {tiledOp},
3534  SmallVector<Value>(tiledOp->getResults()),
3535  llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3536 }
3537 
3538 //===----------------------------------------------------------------------===//
3539 // LinalgDialect
3540 // TODO: Merge with the LinalgDialect block at the bottom
3541 //===----------------------------------------------------------------------===//
3542 
3543 // Returns true if the result expressions of `subMap` are a subset of those
3544 // of `fullMap`.
3544 static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3545  auto explicitRange = subMap.getResults();
3546  auto defaultRange = fullMap.getResults();
3547  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3548  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3549  llvm::set_union(explicitSet, defaultSet);
3550  return explicitSet == defaultSet;
3551 }
3552 
3553 /// Check if the user-defined map is a valid broadcast map. Here broadcast
3554 /// indexing maps are defined in the context of the corresponding default
3555 /// indexing maps for the given op, so the check reduces to comparing the
3556 /// number of result dims.
3557 /// Returns true if `explicitMap` is broadcast with respect to
3558 /// `defaultMap`.
3559 static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
3560  return explicitMap.getNumResults() < defaultMap.getNumResults();
3561 }
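// For example (illustrative): against matmul's default LHS map
// (d0, d1, d2) -> (d0, d2), a user-supplied map (d0, d1, d2) -> (d2) has
// fewer result exprs and is therefore classified as a broadcast.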
3562 
3563 /// Verifies the broadcast and transpose semantics specified by the explicit
3564 /// indexing map for the MatmulOp \p matmulOp for each operand specified by
3565 /// \p opIndex.
3566 static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3567  unsigned opIndex) {
3568  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3569  SmallVector<AffineMap, 3> defaultIndexingMaps =
3570  matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3571 
3572  auto opIndexingMap = opIndexingMaps[opIndex];
3573  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3574  // Check general validity of indexing map results.
3575  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3576  return matmulOp->emitOpError()
3577  << "Unexpected dim expression in map result.";
3578 
3579  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3580  if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3581  return matmulOp->emitOpError()
3582  << "Invalid broadcast requested, should be (d2).";
3583  }
3584  return success();
3585  }
3586  return success();
3587 }
3588 
3589 // Check general validity of input indexing map of
3590 // BatchMatmulOp/BatchReduceMatmulOp.
3591 template <typename OpTy>
3592 static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3593  AffineMap opIndexingMap,
3594  AffineMap defaultIndexingMap, bool isLHS) {
3595  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3596  isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3597  "Expected BatchMatmulOp or BatchReduceMatmulOp");
3598  // Check the result dims are valid.
3599  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3600  return batchVariantMatmulOp->emitOpError()
3601  << "Unexpected result dim expression (outside the set of default "
3602  "result dims).";
3603 
3604  // Check for valid number of result dims of input maps.
3605  if (opIndexingMap.getNumResults() > 3)
3606  return batchVariantMatmulOp->emitOpError()
3607  << "no. of result dim expressions exceeds 3.";
3608 
3609  auto hasValidBatchDim = [](AffineMap map) {
3610  AffineExpr batchDim = map.getResult(0);
3611  return batchDim.isFunctionOfDim(0);
3612  };
3613 
3614  // Check if the requested broadcast is valid.
3615  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3616  if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3617  return batchVariantMatmulOp->emitOpError()
3618  << "Invalid broadcast requested.";
3619  } else if (!hasValidBatchDim(opIndexingMap)) {
3620  return batchVariantMatmulOp->emitOpError()
3621  << "Invalid batch dimension expression.";
3622  }
3623  return success();
3624 }
3625 
3626 /// This function checks if the given AffineMap for the output of a
3627 /// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3628 /// dimensions and if the output map result dimensions are valid.
3629 template <typename OpTy>
3630 static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3631  AffineMap opIndexingMap) {
3632  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3633  isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3634  "Expected BatchMatmulOp or BatchReduceMatmulOp");
3635  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3636  opIndexingMap.getNumResults() != 3) {
3637 
3638  return batchVariantMatmulOp->emitOpError()
3639  << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3640  << ").";
3641  }
3642  if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3643  opIndexingMap.getNumResults() != 2) {
3644  return batchVariantMatmulOp->emitOpError()
3645  << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3646  << ").";
3647  }
3648 
3649  auto areValidOutputResultDim = [&](AffineMap outputMap) {
3650  return isa<BatchMatmulOp>(batchVariantMatmulOp)
3651  ? outputMap.getResult(0).isFunctionOfDim(0) &&
3652  outputMap.getResult(1).isFunctionOfDim(1) &&
3653  outputMap.getResult(2).isFunctionOfDim(2)
3654  : outputMap.getResult(0).isFunctionOfDim(1) &&
3655  outputMap.getResult(1).isFunctionOfDim(2);
3656  };
3657 
3658  if (!areValidOutputResultDim(opIndexingMap)) {
3659  return batchVariantMatmulOp->emitOpError()
3660  << "Invalid output map result dimension.";
3661  }
3662 
3663  return success();
3664 }
3665 
3666 /// Verifies the broadcast and transpose semantic specified by the explicit
3667 /// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3668 /// specified by opIndex.
3669 template <typename OpTy>
3670 static LogicalResult
3671 verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
3672  unsigned opIndex) {
3673  SmallVector<AffineMap, 3> opIndexingMaps =
3674  batchVariantMatmulOp.getIndexingMapsArray();
3675  SmallVector<AffineMap, 3> defaultIndexingMaps =
3676  batchVariantMatmulOp.getDefaultIndexingMaps(
3677  batchVariantMatmulOp->getContext());
3678 
3679  if (opIndexingMaps.size() != 3)
3680  return batchVariantMatmulOp->emitOpError()
3681  << "Indexing_map attribute must have 3 affine maps.";
3682 
3683  auto opIndexingMap = opIndexingMaps[opIndex];
3684  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3685 
3686  if (opIndex == 2 &&
3687  failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3688  return failure();
3689 
3690  if (opIndex != 2 &&
3691  failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3692  defaultIndexingMap, opIndex == 0)))
3693  return failure();
3694 
3695  return success();
3696 }
3697 
3698 namespace mlir {
3699 namespace linalg {
3700 
3701 std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3702  if (m == 2 && r == 3)
3703  return WinogradConv2DFmr::F_2_3;
3704  if (m == 4 && r == 3)
3705  return WinogradConv2DFmr::F_4_3;
3706  if (m == 2 && r == 5)
3707  return WinogradConv2DFmr::F_2_5;
3708  return std::nullopt;
3709 }
3710 
3711 std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3712  switch (fmr) {
3713  case WinogradConv2DFmr::F_2_3:
3714  return {2, 3};
3715  case WinogradConv2DFmr::F_4_3:
3716  return {4, 3};
3717  case WinogradConv2DFmr::F_2_5:
3718  return {2, 5};
3719  }
  llvm_unreachable("unhandled WinogradConv2DFmr");
3720 }
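// Note: on the supported configurations these two helpers are inverses,
// e.g. getWinogradConv2DFmr(2, 3) yields F_2_3 and
// getFmrFromWinogradConv2DFmr(WinogradConv2DFmr::F_2_3) yields {2, 3}.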
3721 
3722 //===----------------------------------------------------------------------===//
3723 // MatMulOp
3724 //===----------------------------------------------------------------------===//
3725 
3726 /// Returns a list of AffineMap with the typical matmul indexing characteristics.
3727 SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3728  AffineExpr d0, d1, d2;
3729  SmallVector<AffineMap> indexingMaps;
3730  bindDims(context, d0, d1, d2);
3731  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3732  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3733  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3734  return indexingMaps;
3735 }
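// In textual IR the three default maps read as (illustrative):
//   affine_map<(d0, d1, d2) -> (d0, d2)>  // LHS:    M x K
//   affine_map<(d0, d1, d2) -> (d2, d1)>  // RHS:    K x N
//   affine_map<(d0, d1, d2) -> (d0, d1)>  // output: M x N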
3736 
3737 SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3738  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3739  utils::IteratorType::parallel,
3740  utils::IteratorType::reduction};
3741 }
3742 
3743 unsigned MatmulOp::getNumRegionArgs() { return 3; }
3744 
3745 std::string MatmulOp::getLibraryCallName() {
3746  return generateLibraryCallName(getOperation());
3747 }
3748 
3749 bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3750 
3751 /// Check if the op has broadcast and/or transpose semantics. Returns true if
3752 /// the user-defined indexing maps are not equal to the default maps.
3753 bool MatmulOp::hasUserDefinedMaps() {
3754  SmallVector<AffineMap, 3> defaultMaps =
3755  getDefaultIndexingMaps(this->getContext());
3756  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3757  return defaultMaps != explicitMaps;
3758 }
3759 
3760 /// Implements the block region builder for the MatmulOp. This is called by
3761 /// 'fillStructuredOpRegion'.
3762 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3763  ArrayRef<NamedAttribute> attrs,
3764  function_ref<InFlightDiagnostic()> emitError) {
3765  if (emitError && block.getNumArguments() != 3) {
3766  emitError() << "MatmulOp regionBuilder expects 3 args, got "
3767  << block.getNumArguments();
3768  return;
3769  }
3770  assert(block.getNumArguments() == 3 &&
3771  "MatmulOp regionBuilder expects 3 args");
3772  RegionBuilderHelper helper(b, block);
3773  SmallVector<Value> yields;
3774 
3775  TypeFn castVal = TypeFn::cast_signed;
3776  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3777  return attr.getName() == "cast";
3778  });
3779  if (castIter != attrs.end()) {
3780  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3781  castVal = attr.getValue();
3782  }
3783 
3784  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3785  block.getArgument(0));
3786  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3787  block.getArgument(1));
3788  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
3789  if (!value3)
3790  return;
3791  Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3792  value3, emitError);
3793  if (!value4)
3794  return;
3795  yields.push_back(value4);
3796  helper.yieldOutputs(yields);
3797 }
3798 
3799 /// Returns true if the given bcastMap is a valid broadcast map. A valid
3800 /// broadcast map must include the K dimension.
3801 /// TODO: Strict inclusion of the K dimension in the broadcast map is not
3802 /// necessary for both input matrices simultaneously. We can relax this
3803 /// condition to require the K dimension in only one input matrix map and
3804 /// infer the K dimension for the other input matrix map from the one that
3805 /// already has it.
3806 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3807  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3808  AffineExpr expr = bcastMap.getResult(0);
3809  // The map is invalid if the common (K) dimension of the matmul is not found.
3810  return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3811 }
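// Illustration: with iteration dims (d0, d1, d2) = (M, N, K), the map
// (d0, d1, d2) -> (d2) is a valid broadcast map (it references K, the
// last dim), whereas (d0, d1, d2) -> (d0) is not.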
3812 
3813 FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3814  if (parser.parseOptionalKeyword("indexing_maps"))
3815  return ArrayAttr{
3816  nullptr}; // Success in case indexing_maps was not provided.
3817 
3818  ArrayAttr arrayAttr;
3819  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3820  return failure();
3821 
3822  if (llvm::any_of(arrayAttr,
3823  [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3824  return parser.emitError(parser.getCurrentLocation())
3825  << "element of indexing_maps array is not an affine_map";
3826 
3827  return arrayAttr;
3828 }
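// The optional attribute parsed here looks like (illustrative):
//   indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                    affine_map<(d0, d1, d2) -> (d2, d1)>,
//                    affine_map<(d0, d1, d2) -> (d0, d1)>]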
3829 
3830 ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3831  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3832  if (failed(indexingMapsAttr))
3833  return failure();
3834 
3835  if (*indexingMapsAttr == nullptr) {
3836  auto indexingMapAttrs = llvm::map_to_vector(
3837  MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3838  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3839  indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3840  }
3841 
3842  result.addAttribute("indexing_maps", *indexingMapsAttr);
3843  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3844  MatmulOp::getRegionBuilder());
3845 }
3846 
3847 void MatmulOp::print(OpAsmPrinter &p) {
3848  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3849  MatmulOp::getDefaultIndexingMaps(getContext()),
3850  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3851  if (!llvm::equal(getIndexingMaps(), indexingMaps))
3852  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3853 
3854  std::array<StringRef, 3> elidedAttrs = {
3855  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3856  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3857  elidedAttrs);
3858 }
3859 
3860 /// Verify the user-defined indexing maps.
3861 LogicalResult MatmulOp::verify() {
3862  // Verification of pure matmul is handled by verifyStructuredOpInterface().
3863  if (!hasUserDefinedMaps())
3864  return success();
3865 
3866  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3867  if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3868  return failure();
3869  }
3870  return success();
3871 }
3872 
3873 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3874  return memref::foldMemRefCast(*this);
3875 }
3876 
3877 void MatmulOp::getEffects(
3878  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3879  &effects) {
3880  if (hasPureTensorSemantics())
3881  return;
3882  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3883 }
3884 
3885 Speculation::Speculatability MatmulOp::getSpeculatability() {
3886  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3887 }
3888 
3889 //===----------------------------------------------------------------------===//
3890 // ContractOp
3891 //===----------------------------------------------------------------------===//
3892 
3893 SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
3894  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3895  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
3896  // domains are all the same, and each implements a projected permutation.
3897  // Each iteration space dim must occur for at least one operand and either
3898  // takes part in a contraction/reduction or else has parallel iteration type.
3899  // We have that a dim is a contraction/reduction dim if and only if the dim
3900  // occurs for the output operand. We use this fact for fast inference:
3901  // NB: In case we allow dims to occur solely for one input, the above still
3902  // holds: per the einsum semantics, these are reduction dims as well.
3903  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
3904  for (auto result : outAffineMap.getResults()) {
3905  auto dimExpr = dyn_cast<AffineDimExpr>(result);
3906  assert(dimExpr && "affine_map is a projected permutation");
3907  dimsInOutput[dimExpr.getPosition()] = true;
3908  }
3909 
3910  SmallVector<utils::IteratorType> iteratorTypes;
3911  for (auto dimOccursInOutput : dimsInOutput)
3912  iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3913  : utils::IteratorType::reduction);
3914 
3915  return iteratorTypes;
3916 }
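// Illustration: for a plain matmul expressed as linalg.contract, the
// output map is (d0, d1, d2) -> (d0, d1); d0 and d1 occur in it and are
// inferred parallel, while d2 does not and is inferred to be a reduction.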
3917 
3918 unsigned ContractOp::getNumRegionArgs() { return 3; }
3919 
3920 /// Implement block region builder, which is called by 'fillStructuredOpRegion'.
3921 void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3922  ArrayRef<NamedAttribute> attrs,
3923  function_ref<InFlightDiagnostic()> emitError) {
3924  if (emitError && block.getNumArguments() != 3) {
3925  emitError() << "ContractOp regionBuilder expects 3 args, got "
3926  << block.getNumArguments();
3927  return;
3928  }
3929  assert(block.getNumArguments() == 3 &&
3930  "ContractOp regionBuilder expects 3 args");
3931  RegionBuilderHelper helper(b, block);
3932 
3933  TypeFn castSignedness = TypeFn::cast_signed;
3934  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3935  return attr.getName() == "cast";
3936  });
3937  if (castIter != attrs.end()) {
3938  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3939  castSignedness = attr.getValue();
3940  }
3941 
3942  // TODO: Support fields with operators besides mult & add.
3943  Type outType = block.getArgument(2).getType();
3944  Value lhsAtOutType =
3945  helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
3946  Value rhsAtOutType =
3947  helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
3948  Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
3949  rhsAtOutType, emitError);
3950  if (!productAtOutType)
3951  return;
3952  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3953  productAtOutType, emitError);
3954  if (!result)
3955  return;
3956  helper.yieldOutputs({result});
3957 }
3958 
3959 ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
3960  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3961  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
3962  return parser.emitError(parser.getCurrentLocation(),
3963  "expected 'indexing_maps' attribute");
3964  result.addAttribute("indexing_maps", *indexingMapsAttr);
3965 
3966  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
3967  regionBuilder);
3968 }
3969 
3971  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3973  p, getOperation(), getInputs(), getOutputs(),
3974  /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
3975 }
3976 
3977 LogicalResult ContractOp::verify() {
3978  int iterationSpaceDims = -1;
3979  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
3980  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
3981  // access an input operand (so occurrence count can be at most 2) and
3982  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
3983  SmallVector<size_t> inOccurrences;
3984  SmallVector<size_t> outOccurrences;
3985 
3986  // A helper so that for each operand's affine_map and type we check that ...
3987  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
3988  bool isInput) -> LogicalResult {
3989  // ... the affine_map is a projected permutation;
3990  if (!affineMap.isProjectedPermutation())
3991  return emitError("provided affine_map is not a projected permutation");
3992 
3993  // ... the rank of the affine_map's results and corresponding type match;
3994  if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
3995  if (affineMap.getNumResults() != shapedType.getRank())
3996  return emitError("ranks of shaped operand and results of corresponding "
3997  "affine_map differ");
3998  } else if (affineMap.getNumResults() != 0) {
3999  return emitError("affine_map specifies shaped access while operand has "
4000  "non-shaped type");
4001  }
4002 
4003  // ... the rank of the affine_map's domain is the same as those seen prior;
4004  if (iterationSpaceDims == -1) {
4005  iterationSpaceDims = affineMap.getNumDims();
4006  inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4007  outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4008  } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
4009  return emitError("iteration spaces of provided affine_maps differ");
4010  }
4011 
4012  // ... update counts of dims used to access either an input or the output.
4013  for (AffineExpr affineExpr : affineMap.getResults()) {
4014  auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4015  if (!affineDimExpr)
4016  llvm_unreachable("affine_map is a projected permutation");
4017 
4018  if (isInput)
4019  inOccurrences[affineDimExpr.getPosition()] += 1;
4020  else
4021  outOccurrences[affineDimExpr.getPosition()] += 1;
4022  }
4023 
4024  return success();
4025  };
4026 
4027  for (auto &&[affineMap, operandType, isInput] :
4028  llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4029  SmallVector<bool>{true, true, false})) {
4030  if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4031  return failure(); // NB: checkAffineMapAndType will emit relevant error.
4032  }
4033 
4034  bool hasContractingDim = false;
4035  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4036  size_t inOccCount = inOccurrences[dimIndex];
4037  size_t outOccCount = outOccurrences[dimIndex];
4038 
4039  // We have a contracting dim if and only if ...
4040  hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4041 
4042  if (inOccCount == 0 && outOccCount == 0)
4043  return emitError() << "iteration space dim at index " << dimIndex
4044  << " not used to access any operand";
4045 
4046  // NB: We disallow a dim which occurs for only one input operand and not
4047  // for the output. In terms of einsum semantics such dims have a
4048  // sensible meaning - namely an additional reduction per each such dim.
4049  // By contrast, the ContractionOpInterface does not know about this
4050  // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
4051  // while vector.contract's verifier accepts dims of this kind many of
4052  // its lowerings give up on encountering these dims.
4053  // TODO: Remove following once we have comprehensive support for input-only
4054  // reduction dims, at both the linalg- and vector-dialect levels.
4055  if (inOccCount == 1 && outOccCount != 1)
4056  return emitError()
4057  << "iteration space dim at index " << dimIndex
4058  << " is neither a contracting dim nor of parallel iteration type";
4059  }
4060 
4061  if (!hasContractingDim)
4062  return emitError("'indexing_maps' do not specify a contracting dimension");
4063 
4064  return success();
4065 }
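// Illustration: the maps {(d0, d1) -> (d0, d1), (d0, d1) -> (d1),
// (d0, d1) -> (d0)} describe a matvec and pass this verifier: d1 occurs
// in both inputs but not the output (the contracting dim) and d0 is
// parallel. Dropping d1 from both input maps would instead trip the
// "not used to access any operand" diagnostic above.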
4066 
4067 LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4068  return memref::foldMemRefCast(*this);
4069 }
4070 
4071 void ContractOp::getEffects(
4072  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4073  &effects) {
4074  if (hasPureTensorSemantics())
4075  return;
4076  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4077 }
4078 
4079 Speculation::Speculatability ContractOp::getSpeculatability() {
4080  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4081 }
4082 
4083 //===----------------------------------------------------------------------===//
4084 // Implementation of BatchMatmulOp
4085 //===----------------------------------------------------------------------===//
4086 SmallVector<AffineMap>
4087 BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4088  AffineExpr d0, d1, d2, d3;
4089  SmallVector<AffineMap> indexingMaps;
4090  bindDims(context, d0, d1, d2, d3);
4091  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4092  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4093  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4094  return indexingMaps;
4095 }
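// In textual IR the three default maps read as (illustrative):
//   affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>  // LHS:    B x M x K
//   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>  // RHS:    B x K x N
//   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>  // output: B x M x N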
4096 
4097 SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4098  return SmallVector<utils::IteratorType>{
4099  utils::IteratorType::parallel, utils::IteratorType::parallel,
4100  utils::IteratorType::parallel, utils::IteratorType::reduction};
4101 }
4102 
4103 unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4104 
4105 std::string BatchMatmulOp::getLibraryCallName() {
4106  return generateLibraryCallName(getOperation());
4107 }
4108 
4109 /// Check if the op has broadcast and/or transpose semantics. Returns true if
4110 /// the user-defined indexing maps are not equal to the default maps.
4111 bool BatchMatmulOp::hasUserDefinedMaps() {
4112  SmallVector<AffineMap, 3> defaultMaps =
4113  getDefaultIndexingMaps(this->getContext());
4114  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4115  return defaultMaps != explicitMaps;
4116 }
4117 
4118 /// Returns true if the given bcastMap is a valid broadcast map. A valid
4119 /// broadcast map must include the K dimension.
4120 /// TODO: Strict inclusion of the K dimension in the broadcast map is not
4121 /// necessary for both input matrices simultaneously. We can relax this
4122 /// condition to require the K dimension in only one input matrix map and
4123 /// infer the K dimension for the other input matrix map from the one that
4124 /// already has it.
4125 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4126  assert(bcastMap.getNumResults() < 3 &&
4127  "Expected less than 3 result dim expr.");
4128  bool isValid = false;
4129  enum Indices { batchPos, mPos, nPos, kPos };
4130  if (bcastMap.getNumResults() == 1) {
4131  AffineExpr expr = bcastMap.getResult(0);
4132  isValid = expr.isFunctionOfDim(kPos);
4133  } else if (bcastMap.getNumResults() == 2) {
4134  AffineExpr expr0 = bcastMap.getResult(0);
4135  AffineExpr expr1 = bcastMap.getResult(1);
4136  isValid =
4137  isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4138  expr0.isFunctionOfDim(mPos)) &&
4139  expr1.isFunctionOfDim(kPos))
4140  : ((expr0.isFunctionOfDim(batchPos) &&
4141  expr1.isFunctionOfDim(kPos)) ||
4142  (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4143  }
4144  return isValid;
4145 }
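// Illustration: with dims (d0, d1, d2, d3) = (batch, M, N, K), a
// single-result broadcast map must be (d3); valid two-result maps include
// (d0, d3) or (d1, d3) for the LHS and (d0, d3) or (d3, d2) for the RHS.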
4146 
4147 void BatchMatmulOp::regionBuilder(
4148  ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4149  function_ref<InFlightDiagnostic()> emitError) {
4150  if (emitError && block.getNumArguments() != 3) {
4151  emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4152  << block.getNumArguments();
4153  return;
4154  }
4155  assert(block.getNumArguments() == 3 &&
4156  "BatchMatmulOp regionBuilder expects 3 args");
4157  RegionBuilderHelper helper(b, block);
4158  SmallVector<Value> yields;
4159 
4160  TypeFn castVal = TypeFn::cast_signed;
4161  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4162  return attr.getName() == "cast";
4163  });
4164  if (castIter != attrs.end()) {
4165  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4166  castVal = attr.getValue();
4167  }
4168 
4169  auto toType = block.getArgument(2).getType();
4170  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4171  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4172  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
4173  Value addVal =
4174  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
4175  yields.push_back(addVal);
4176  helper.yieldOutputs(yields);
4177 }
4178 
4179 ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4180  SmallVector<Attribute, 3> indexingMapsAttr;
4181  Attribute mapAttr;
4182  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4183  if (parser.parseEqual())
4184  return failure();
4185 
4186  if (parser.parseLSquare())
4187  return failure();
4188 
4189  do {
4190  if (parser.parseAttribute(mapAttr))
4191  return failure();
4192  if (!isa<AffineMapAttr>(mapAttr)) {
4193  return parser.emitError(parser.getCurrentLocation(),
4194  "expected affine map attribute");
4195  }
4196  indexingMapsAttr.push_back(mapAttr);
4197 
4198  if (parser.parseOptionalComma())
4199  break;
4200  } while (true);
4201 
4202  if (parser.parseRSquare())
4203  return failure();
4204  }
4205  // Initialize indexingMaps, if not supplied explicitly.
4206  if (indexingMapsAttr.empty()) {
4207  indexingMapsAttr = llvm::map_to_vector(
4208  BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4209  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4210  }
4211  result.addAttribute("indexing_maps",
4212  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4213 
4214  return ::parseNamedStructuredOp(parser, result,
4215  BatchMatmulOp::getNumRegionArgs(),
4216  BatchMatmulOp::getRegionBuilder());
4217 }
4218 
4219 void BatchMatmulOp::print(OpAsmPrinter &p) {
4220  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4221  BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4222  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4223  if (!llvm::equal(getIndexingMaps(), indexingMaps))
4224  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4225 
4226  std::array<StringRef, 3> elidedAttrs = {
4227  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4228  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4229  elidedAttrs);
4230 }
4231 
4232 /// Verify the user-defined indexing maps.
4233 LogicalResult BatchMatmulOp::verify() {
4234  // Verification of pure batch_matmul is handled by
4235  // verifyStructuredOpInterface().
4236  if (!hasUserDefinedMaps())
4237  return success();
4238 
4239  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4240  if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
4241  return failure();
4242  }
4243  return success();
4244 }
4245 
4246 LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4247  SmallVectorImpl<OpFoldResult> &) {
4248  return memref::foldMemRefCast(*this);
4249 }
4250 
4251 void BatchMatmulOp::getEffects(
4252  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4253  &effects) {
4254  if (hasPureTensorSemantics())
4255  return;
4256  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4257 }
4258 
4259 Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4260  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4261 }
4262 
4263 //===----------------------------------------------------------------------===//
4264 // ElementwiseOp
4265 //===----------------------------------------------------------------------===//
4266 //
4267 namespace {
4268 struct ArityGroupAndKind {
4269  // The enum class {Unary, Binary, Ternary, ..}
4270  ElementwiseArityGroup arityGroup;
4271 
4272  // The kind (e.g. `exp` or `add`) belonging to the arity group.
4273  union Kind {
4274  UnaryFn unaryFn;
4275  BinaryFn binaryFn;
4276  TernaryFn ternaryFn;
4277  } kind;
4278 };
4279 
4280 unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4281  return static_cast<unsigned>(arityGroup);
4282 }
4283 } // namespace
4284 
4285 static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4286  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4287  constexpr int lastBinary =
4288  static_cast<int>(ElementwiseCaseLimits::LastBinary);
4289  constexpr int lastTernary =
4290  static_cast<int>(ElementwiseCaseLimits::LastTernary);
4291 
4292  int val = static_cast<int>(kind);
4293  ArityGroupAndKind result;
4294 
4295  if (val < lastUnary) {
4296  result.arityGroup = ElementwiseArityGroup::Unary;
4297  result.kind.unaryFn = static_cast<UnaryFn>(val);
4298  return result;
4299  }
4300  if (val < lastBinary) {
4301  result.arityGroup = ElementwiseArityGroup::Binary;
4302  result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4303  return result;
4304  }
4305  if (val >= lastTernary) {
4306  llvm_unreachable("unhandled ElementwiseFn");
4307  }
4308  result.arityGroup = ElementwiseArityGroup::Ternary;
4309  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4310  return result;
4311 }
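// Illustration: ElementwiseKind values are laid out unary kinds first,
// then binary, then ternary, so e.g. a unary kind such as `exp` decodes
// to {Unary, unaryFn} directly, while a binary kind such as `add` decodes
// to {Binary, binaryFn} after subtracting the lastUnary offset.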
4312 
4313 SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4314  auto rank = getResultRank();
4315  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4316 }
4317 
4318 SmallVector<AffineMap>
4319 ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4320  MLIRContext *context) {
4321  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4322  return SmallVector<AffineMap>(numMaps, map);
4323 }
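// Illustration: a binary elementwise op on rank-2 operands gets three
// identical identity maps, i.e. affine_map<(d0, d1) -> (d0, d1)> for both
// inputs and the output.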
4324 
4325 ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4326  // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4327  Attribute attr;
4328  mlir::linalg::ElementwiseKind elemwiseKindVal;
4329  if (parser.parseKeyword("kind") || parser.parseEqual())
4330  return failure();
4331 
4332  if (succeeded(parser.parseAttribute(attr))) {
4333  auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4334  if (!elemwiseKindAttr)
4335  return parser.emitError(parser.getCurrentLocation(),
4336  "expected ElementwiseKind attribute");
4337  elemwiseKindVal = elemwiseKindAttr.getValue();
4338  } else {
4339  return parser.emitError(parser.getCurrentLocation(),
4340  "expected operation 'kind' attribute");
4341  }
4342  result.addAttribute(
4343  "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4344 
4345  // Parse optional `indexing_maps`
4346  SmallVector<Attribute, 3> indexingMapsAttr;
4347  Attribute mapAttr;
4348  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4349  if (parser.parseEqual())
4350  return failure();
4351  if (parser.parseLSquare())
4352  return failure();
4353  do {
4354  if (parser.parseAttribute(mapAttr))
4355  return failure();
4356  if (!isa<AffineMapAttr>(mapAttr))
4357  return parser.emitError(parser.getCurrentLocation(),
4358  "expected affine map attribute");
4359  indexingMapsAttr.push_back(mapAttr);
4360  if (parser.parseOptionalComma())
4361  break;
4362  } while (true);
4363  if (parser.parseRSquare())
4364  return failure();
4365  }
4366  // At this stage of parsing, the only way to infer the number of region
4367  // args is through the op kind; the input/output tensors are not parsed yet.
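 // E.g. a unary kind such as `exp` implies 2 region args (one input plus the
 // output), while a binary kind such as `add` implies 3.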
4368  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4369  int numRegionArgs =
4370  getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4371  if (parseNamedStructuredOp(parser, result, numRegionArgs,
4372  ElementwiseOp::getRegionBuilder())) {
4373  return parser.emitError(parser.getCurrentLocation(),
4374  "unable to parse elemwise op");
4375  }
4376 
4377  // Initialize indexingMaps, if not supplied explicitly.
4378  if (indexingMapsAttr.empty()) {
4379  // We need to infer the numDims of the indexing maps from the output
4380  // type, which has already been parsed by now.
4381  auto resultType = result.operands[result.operands.size() - 1].getType();
4382  auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4383  if (!shapedType)
4384  return parser.emitError(parser.getCurrentLocation(),
4385  "return type needs to be shaped type");
4386  auto numDims = shapedType.getRank();
4387  indexingMapsAttr = llvm::map_to_vector(
4388  ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4389  parser.getContext()),
4390  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4391  }
4392 
4393  result.addAttribute("indexing_maps",
4394  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4395  return success();
4396 }
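// A minimal sketch of the custom syntax round-tripped by parse() above and
// print() below, following the `kind = #linalg.elemwise_kind<...>` form noted
// in parse(); the exact attribute mnemonic is whatever ElementwiseKindAttr
// registers, and the shapes are illustrative:
//
// ```mlir
// %exp = linalg.elementwise kind=#linalg.elemwise_kind<exp>
//        ins(%A : tensor<8x16xf32>) outs(%B : tensor<8x16xf32>) -> tensor<8x16xf32>
// ```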
4397 
4398 void ElementwiseOp::print(OpAsmPrinter &p) {
4399  p << " kind=";
4400  p.printAttribute(getKindAttr());
4401  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4402  "indexing_maps"};
4403  unsigned arity =
4404  getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4405  unsigned numDims = getResultRank();
4406 
4407  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4408  ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4409  getContext()),
4410  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4411 
4412  if (!llvm::equal(getIndexingMaps(), indexingMaps))
4413  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4414 
4415  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4416  elidedAttrs);
4417 }
4418 
4419 LogicalResult ElementwiseOp::verify() {
4420  // All necessary checks are done either by
4421  // - EnumAttr (e.g. unknown operation kind)
4422  // - verifyStructuredOpInterface (incorrect map, sizes).
4423  return success();
4424 }
4425 
4426 /// Implements the block region builder for the ElementwiseOp. This is called by
4427 /// 'fillStructuredOpRegion'.
4428 void ElementwiseOp::regionBuilder(
4429  ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4430  function_ref<InFlightDiagnostic()> emitError) {
4431  ElementwiseKind elemwiseKind;
4432  for (auto attr : attrs) {
4433  if (attr.getName() == b.getStringAttr("kind")) {
4434  auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4435  assert(kindAttr && "op kind attribute incorrectly set");
4436  elemwiseKind = kindAttr.getValue();
4437  break;
4438  }
4439  }
4440 
4441  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4442  auto arityGroup = groupAndKind.arityGroup;
4443  auto kind = groupAndKind.kind;
4444  if (emitError && block.getNumArguments() !=
4445  getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
4446  emitError() << "Elementwise regionBuilder expects "
4447  << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
4448  << block.getNumArguments();
4449  return;
4450  }
4451  assert(block.getNumArguments() ==
4452  getArityGroupAsUInt(arityGroup) + 1 /*output*/
4453  && "Elementwise regionBuilder number of block args mismatch");
4454 
4455  RegionBuilderHelper helper(b, block);
4456  SmallVector<Value> yields;
4457  Value result;
4458 
4459  if (arityGroup == ElementwiseArityGroup::Unary) {
4460  result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4461 
4462  } else if (arityGroup == ElementwiseArityGroup::Binary) {
4463  result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4464  block.getArgument(1));
4465 
4466  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4467  result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4468  block.getArgument(1), block.getArgument(2));
4469 
4470  } else {
4471  assert(false && "found unhandled category in elemwise");
4472  }
4473 
4474  yields.push_back(result);
4475  helper.yieldOutputs(yields);
4476 }
4477 
4478 LogicalResult ElementwiseOp::fold(FoldAdaptor,
4479  SmallVectorImpl<OpFoldResult> &) {
4480  return memref::foldMemRefCast(*this);
4481 }
4482 
4483 void ElementwiseOp::getEffects(
4484  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4485  &effects) {
4486  if (hasPureTensorSemantics())
4487  return;
4488  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4489 }
4490 
4491 Speculation::Speculatability ElementwiseOp::getSpeculatability() {
4492  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4493 }
4494 
4495 //===----------------------------------------------------------------------===//
4496 // PackOp/UnPackOp Common
4497 //===----------------------------------------------------------------------===//
4498 // Given the (potentially) updated packed type, `newPackedTy`, generates an
4499 // updated mixed-tile-sizes attribute. A tile size is updated only
4500 // when:
4501 // * a dim from newPackedTy is static, and
4502 // * the corresponding size from mixedTiles is still dynamic.
4503 // Otherwise, the original tile size is preserved.
4504 // Note - packed-type-dim and mixed-tile-size should always match!
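// For example (illustrative values): with trailing packed dims [16, 32] in
// `newPackedTy` and mixedTiles = [%c16, 32], the dynamic %c16 is replaced by
// the static 16, while the already-static 32 is passed through unchanged.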
4505 static SmallVector<OpFoldResult>
4506 getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
4507  SmallVector<OpFoldResult> mixedTiles) {
4508  SmallVector<OpFoldResult> newMixedTileSizes;
4509  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
4510  .getShape()
4511  .take_back(mixedTiles.size()),
4512  mixedTiles)) {
4513  int64_t shape = std::get<0>(it);
4514  if (shape == ShapedType::kDynamic) {
4515  newMixedTileSizes.push_back(std::get<1>(it));
4516  continue;
4517  }
4518 
4519  // If the current result dim is static, update the dynamic mixed-size
4520  // (provided the original value is dynamic).
4521  OpFoldResult tile = std::get<1>(it);
4522  if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
4523  // Already a constant
4524  newMixedTileSizes.push_back(tile);
4525  } else {
4526  assert(getConstantIntValue(tile).value() == shape &&
4527  "tile size and dim size don't match!");
4528  newMixedTileSizes.push_back(
4529  (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
4530  }
4531  }
4532 
4533  return newMixedTileSizes;
4534 }
4535 
4536 template <typename OpTy>
4537 static LogicalResult
4538 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
4539  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4540  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4541  "applies to only pack or unpack operations");
4542  int64_t destRank = op.getDestRank();
4543  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
4544  reifiedReturnShapes[0] =
4545  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
4546  return success();
4547 }
4548 
4549 template <typename OpTy>
4550 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
4551  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4552  "applies to only pack or unpack operations");
4553  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
4554  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
4555  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
4556  assert(tiles.size() == dimsToTile.size() &&
4557  "tiles must match indices of dimension to block");
4558  // Bind dimension `i` to its tile factor.
4559  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
4560  dimAndTileMapping[dimsToTile[i]] = tiles[i];
4561  return dimAndTileMapping;
4562 }
4563 
4564 template <typename OpTy>
4565 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
4566  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4567  "applies to only pack or unpack operations");
4568  Builder builder(op);
4569  SmallVector<OpFoldResult> mixedInnerTiles;
4570  unsigned dynamicValIndex = 0;
4571  for (int64_t staticTile : op.getStaticInnerTiles()) {
4572  if (!ShapedType::isDynamic(staticTile))
4573  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
4574  else
4575  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
4576  }
4577  return mixedInnerTiles;
4578 }
4579 
4580 template <typename OpTy>
4581 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
4582  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4583  "applies to only pack or unpack operations");
4584  SmallVector<Value> dynamicTiles;
4585  SmallVector<int64_t> staticTiles;
4586  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
4587  return staticTiles;
4588 }
4589 
4590 /// Returns true if `dimsPos` is invalid. It is invalid when:
4591 /// a) It contains duplicates.
4592 /// b) At least one dimension is out of bounds (valid positions satisfy 0 <= `dimPos` < `rank`).
4593 /// c) The number of elements in `dimsPos` is greater than `rank`.
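/// For example, dimsPos = [0, 0] (duplicate entry) and dimsPos = [2] with
/// rank = 2 (out of bounds) are both invalid.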
4594 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
4595  size_t rank) {
4596  size_t dimsPosSize = dimsPos.size();
4597  if (dimsPosSize > rank)
4598  return true;
4599  DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
4600  if (dimsPosSize != uniqued.size())
4601  return true;
4602  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
4603  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
4604  });
4605 }
4606 
4607 /// Returns true if every dimension of `sourceShape` is no larger than the
4608 /// corresponding dimension of `limitShape`; dynamic dimensions always pass.
4609 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
4610  ArrayRef<int64_t> limitShape) {
4611  assert(
4612  sourceShape.size() == limitShape.size() &&
4613  "expected source shape rank, and limit of the shape to have same rank");
4614  return llvm::all_of(
4615  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
4616  int64_t sourceExtent = std::get<0>(it);
4617  int64_t limit = std::get<1>(it);
4618  return ShapedType::isDynamic(sourceExtent) ||
4619  ShapedType::isDynamic(limit) || sourceExtent <= limit;
4620  });
4621 }
4622 
4623 template <typename OpTy>
4624 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
4625  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4626  "applies to only pack or unpack operations");
4627  Operation *op = packOrUnPack.getOperation();
4628 
4629  // Return true if we have a zero-value tile.
4630  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
4631  return llvm::any_of(tiles, isZeroInteger);
4632  };
4633 
4634  // Verify tiles. Do not allow zero tiles.
4635  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
4636  if (hasZeros(mixedTiles))
4637  return op->emitError("invalid zero tile factor");
4638 
4639  // Verify inner_dims_pos and outer_dims_perm.
4640  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4641  ? packOrUnPack.getSourceType()
4642  : packOrUnPack.getDestType();
4643  size_t unpackedRank = unpackedType.getRank();
4644  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
4645  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
4646  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
4647  return op->emitError("invalid inner_dims_pos vector");
4648  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
4649  return op->emitError("invalid outer_dims_perm vector");
4650  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
4651  return op->emitError("outer_dims_perm must be a permutation or empty");
4652 
4653  // Tiling factors must be less than or equal to the input rank for pack (or
4654  // output rank for unpack), and must match the number of `inner_dims_pos`.
4655  if (mixedTiles.size() > unpackedRank) {
4656  return op->emitError("tiling factors must be less than or equal to the "
4657  "input rank for pack or output rank for unpack");
4658  }
4659  if (mixedTiles.size() != innerDimsPos.size()) {
4660  return op->emitError(
4661  "tiling factors must equal the number of dimensions to tile");
4662  }
4663 
4664  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4665  ? packOrUnPack.getDestType()
4666  : packOrUnPack.getSourceType();
4667  size_t packedRank = packedType.getRank();
4668  // Require output rank to match input rank + number of blocking factors.
4669  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
4670  if (expectedPackedRank != packedRank) {
4671  return op->emitError(
4672  "packed rank != (unpacked rank + num tiling factors), got ")
4673  << packedRank << " != " << expectedPackedRank;
4674  }
4675 
4676  // Verify that the result shape is at least as large as the minimum
4677  // expected by the pack operation, and that the output shape
4678  // represents full tiles.
4679  RankedTensorType expectedPackedType = PackOp::inferPackedType(
4680  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
4681  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
4682  return op->emitError("the shape of output is not large enough to hold the "
4683  "packed data. Expected at least ")
4684  << expectedPackedType << ", got " << packedType;
4685  }
4686  if (!llvm::all_of(
4687  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4688  mixedTiles),
4689  [](std::tuple<int64_t, OpFoldResult> it) {
4690  int64_t shape = std::get<0>(it);
4691  if (Attribute attr =
4692  llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4693  IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
4694  int64_t staticTileSize = intAttr.getValue().getSExtValue();
4695  return shape == staticTileSize;
4696  }
4697  return ShapedType::isDynamic(shape);
4698  })) {
4699  return op->emitError("mismatch in inner tile sizes specified and the shape "
4700  "of the tiled dimension in the packed type");
4701  }
4702  return success();
4703 }
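// A well-formed pack satisfying the checks above, with illustrative shapes:
// the packed rank 4 equals the unpacked rank 2 plus 2 tiling factors, and the
// outer dims are 128/8 = 16 and 256/32 = 8.
//
// ```mlir
// %p = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//      into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
// ```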
4704 
4705 namespace {
4706 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
4707 /// various permutations to the op.
4708 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
4709 // these. These may or may not become true foldings / canonicalizations
4710 // depending on how aggressive we want to be in automatically folding
4711 // transposes.
4712 struct PackOrUnPackTransposeResult {
4713  SmallVector<int64_t> innerDimsPos;
4714  SmallVector<OpFoldResult> innerTiles;
4715  SmallVector<int64_t> outerDimsPerm;
4716 };
4717 } // namespace
4718 
4719 template <typename OpTy>
4720 static PackOrUnPackTransposeResult
4721 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
4722  ArrayRef<int64_t> innerPermutation,
4723  ArrayRef<int64_t> outerPermutation) {
4724  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4725  "applies to only pack or unpack operations");
4726  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4727  "some permutation must be non-empty");
4728  PackOrUnPackTransposeResult metadata;
4729  metadata.innerDimsPos =
4730  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
4731  metadata.innerTiles =
4732  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
4733  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4734  ? packOrUnPackOp.getSourceRank()
4735  : packOrUnPackOp.getDestRank();
4736  metadata.outerDimsPerm =
4737  packOrUnPackOp.getOuterDimsPerm().empty()
4738  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
4739  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
4740  if (!innerPermutation.empty()) {
4741  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4742  isPermutationVector(innerPermutation) &&
4743  "invalid inner permutation");
4744  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
4745  applyPermutationToVector(metadata.innerTiles, innerPermutation);
4746  }
4747  if (!outerPermutation.empty()) {
4748  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4749  isPermutationVector(outerPermutation) &&
4750  "invalid outer permutation");
4751  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
4752  }
4753  return metadata;
4754 }
4755 
4756 //===----------------------------------------------------------------------===//
4757 // PackOp
4758 //===----------------------------------------------------------------------===//
4759 
4760 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
4761  setNameFn(getResult(), "pack");
4762 }
4763 
4764 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
4765  Value dest, ArrayRef<int64_t> innerDimsPos,
4766  ArrayRef<OpFoldResult> innerTiles,
4767  std::optional<Value> paddingValue,
4768  ArrayRef<int64_t> outerDimsPerm) {
4769  assert(innerDimsPos.size() == innerTiles.size() &&
4770  "number of tile sizes specified must match the specified number of "
4771  "original dimensions to be tiled");
4772  SmallVector<int64_t> staticTileSizes;
4773  SmallVector<Value> dynamicTileSizes;
4774  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4775  build(builder, state, dest.getType(), source, dest,
4776  paddingValue ? *paddingValue : nullptr,
4777  outerDimsPerm.empty() ? nullptr
4778  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4779  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4780  builder.getDenseI64ArrayAttr(staticTileSizes));
4781 }
4782 
4783 LogicalResult
4784 PackOp::reifyResultShapes(OpBuilder &builder,
4785  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4786  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4787 }
4788 
4789 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
4790  return getDimAndTileMappingImpl(*this);
4791 }
4792 
4793 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
4794  return getMixedTilesImpl(*this);
4795 }
4796 
4797 SmallVector<int64_t> PackOp::getStaticTiles() {
4798  return getStaticTilesImpl(*this);
4799 }
4800 
4801 ArrayRef<int64_t> PackOp::getAllOuterDims() {
4802  ShapedType inputType = getSourceType();
4803  int64_t inputRank = inputType.getRank();
4804  return getDestType().getShape().take_front(inputRank);
4805 }
4806 
4807 SmallVector<int64_t> PackOp::getTiledOuterDims() {
4808  auto innerDimsPos = getInnerDimsPos();
4809  auto packedShape = getDestType().getShape();
4810  SmallVector<int64_t> res;
4811 
4812  for (auto index : innerDimsPos)
4813  res.push_back(packedShape[index]);
4814 
4815  return res;
4816 }
4817 
4818 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
4819  ArrayRef<int64_t> innerDimsPos,
4820  ArrayRef<int64_t> outputShape,
4821  ArrayRef<int64_t> outerDimsPerm,
4822  ArrayRef<OpFoldResult> innerTiles) {
4823  SmallVector<int64_t> outputTileSizes(
4824  outputShape.take_front(inputShape.size()));
4825  if (!outerDimsPerm.empty()) {
4826  assert(outerDimsPerm.size() == outputTileSizes.size() &&
4827  "expected output and outer_dims_perm to have same size");
4828  applyPermutationToVector(outputTileSizes,
4829  invertPermutationVector(outerDimsPerm));
4830  }
4831  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
4832  if (ShapedType::isDynamic(inputShape[pos]))
4833  continue;
4834  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
4835 
4836  if (!constantTile) {
4837  if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4838  (inputShape[pos] % outputTileSizes[pos] != 0))
4839  return true;
4840  } else if (inputShape[pos] % (*constantTile) != 0) {
4841  return true;
4842  }
4843  }
4844  return false;
4845 }
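// E.g. (illustrative): inputShape = [13, 15] with innerDimsPos = [0] and
// innerTiles = [8] requires a padding value because 13 % 8 != 0, whereas
// inputShape = [16, 15] with the same tile does not.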
4846 
4847 LogicalResult PackOp::verify() {
4848  if (failed(commonVerifierPackAndUnPackOp(*this)))
4849  return failure();
4850 
4851  // Verify padding value, and bail out if the tile does not divide the
4852  // dimension fully. In the case of dynamic tile factors or dimensions, having
4853  // a partial tile is undefined behavior.
4854  auto paddingValue = getPaddingValue();
4855  if (paddingValue &&
4856  paddingValue.getType() != getSourceType().getElementType()) {
4857  return emitOpError("expected padding_value has ")
4858  << getSourceType().getElementType()
4859  << " but got: " << paddingValue.getType();
4860  }
4861 
4862  if (!paddingValue &&
4863  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
4864  getDestType().getShape(), getOuterDimsPerm(),
4865  getMixedTiles())) {
4866  return emitOpError(
4867  "invalid tile factor or output size provided. Only full tiles are "
4868  "supported when padding_value is not set");
4869  }
4870  return success();
4871 }
4872 
4873 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
4874 /// all Values to kDynamic, even if they are arith.constant values.
4875 static SmallVector<int64_t>
4876 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4877  SmallVector<int64_t> result;
4878  for (auto o : ofrs) {
4879  // Have to do this first, as getConstantIntValue special-cases constants.
4880  if (llvm::dyn_cast_if_present<Value>(o))
4881  result.push_back(ShapedType::kDynamic);
4882  else
4883  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
4884  }
4885  return result;
4886 }
4887 
4888 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
4889 /// the packed type. Having a shared helper helps implement these two methods in
4890 /// a way that ensures that they agree on which dimensions are dynamic.
4891 static SmallVector<int64_t> getPackOpResultTypeShape(
4892  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4893  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4894  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
4895  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4896  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4897  continue;
4898  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4899  resultShape[tiledDim.value()] = ShapedType::kDynamic;
4900  continue;
4901  }
4902  resultShape[tiledDim.value()] = llvm::divideCeilSigned(
4903  resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4904  }
4905 
4906  // Swap tile loops if outer_dims_perm is available.
4907  if (!outerDimsPerm.empty())
4908  applyPermutationToVector(resultShape, outerDimsPerm);
4909 
4910  // Append the inner tile dimensions.
4911  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4912  return resultShape;
4913 }
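// E.g. (illustrative): sourceShape = [?, 256], innerDimsPos = [1],
// innerTileSizes = [32] and an empty outerDimsPerm produce [?, 8, 32]:
// dim 1 becomes ceil(256/32) = 8 and the tile size is appended.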
4914 
4915 SmallVector<OpFoldResult> PackOp::getResultShape(
4916  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4917  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
4918  ArrayRef<int64_t> outerDimsPerm) {
4919  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
4920 
4921  AffineExpr s0, s1;
4922  bindSymbols(builder.getContext(), s0, s1);
4923  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4924  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4925  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4926  builder, loc, ceilDivExpr,
4927  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4928  }
4929  if (!outerDimsPerm.empty())
4930  applyPermutationToVector(resultDims, outerDimsPerm);
4931  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4932 
4933  SmallVector<int64_t> resultTypeShape =
4934  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
4935  asShapeWithAnyValueAsDynamic(innerTileSizes),
4936  innerDimsPos, outerDimsPerm);
4937 
4938  // Fix up `resultDims` to ensure that they are Values if and only if the
4939  // result type shape says the dim is dynamic. This is needed as callers may
4940  // use dispatchIndexOpFoldResults on the result, and rely on the exact number
4941  // of dynamic dims returned by that.
4942  for (unsigned i = 0; i < resultDims.size(); ++i) {
4943  if (!ShapedType::isDynamic(resultTypeShape[i]))
4944  continue;
4945  resultDims[i] =
4946  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4947  }
4948 
4949  return resultDims;
4950 }
4951 
4952 /// Get the expected packed type based on source type, tile factors, position of
4953 /// the inner tiles and permutation of the outer tiled loop.
4954 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4955  ArrayRef<int64_t> innerTileSizes,
4956  ArrayRef<int64_t> innerDimsPos,
4957  ArrayRef<int64_t> outerDimsPerm) {
4958  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4959  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4960  return RankedTensorType::get(resultShape, sourceType.getElementType());
4961 }
4962 
4963 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4964  ArrayRef<OpFoldResult> innerTileSizes,
4965  ArrayRef<int64_t> innerDimsPos,
4966  ArrayRef<int64_t> outerDimsPerm) {
4967  AffineExpr dim0, dim1;
4968  bindDims(b.getContext(), dim0, dim1);
4969  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4970  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4971  {v1, v2});
4972  };
4973 
4974  SmallVector<OpFoldResult> mixedSizes;
4975  for (auto [index, value] : llvm::enumerate(
4976  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4977  if (ShapedType::isDynamic(value))
4978  mixedSizes.push_back(
4979  b.create<tensor::DimOp>(loc, source, index).getResult());
4980  else
4981  mixedSizes.push_back(b.getIndexAttr(value));
4982  }
4983  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4984  int64_t dimPos = std::get<0>(it);
4985  OpFoldResult tileSize = std::get<1>(it);
4986  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4987  }
4988  if (!outerDimsPerm.empty())
4989  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4990 
4991  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4992  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4993  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4994 }
4995 
4996 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4997  ArrayRef<int64_t> innerPermutation,
4998  ArrayRef<int64_t> outerPermutation) {
4999  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5000  *this, innerPermutation, outerPermutation);
5001  Value transposedDest =
5002  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
5003  metadata.innerDimsPos, metadata.outerDimsPerm);
5004  return b.create<PackOp>(loc, getSource(), transposedDest,
5005  metadata.innerDimsPos, metadata.innerTiles,
5006  getPaddingValue(), metadata.outerDimsPerm);
5007 }
5008 
5009 /// Returns true if the tiles and the tiled dims are constant.
5010 template <typename OpTy>
5011 static bool areTilesAndTiledDimsAllConstant(OpTy op) {
5012  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5013  "applies to only pack or unpack operations");
5014  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5015  ? op.getDestType()
5016  : op.getSourceType();
5017  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
5018  for (auto [dimDest, tile] : llvm::zip(
5019  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5020  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
5021  if (!constTileSize || ShapedType::isDynamic(dimDest))
5022  return false;
5023  }
5024  return true;
5025 }
5026 
5027 Speculation::Speculatability PackOp::getSpeculatability() {
5028  if (getPaddingValue())
5029  return Speculation::Speculatable;
5030 
5031  // The verifier already rejects operations where we can statically prove
5032  // that the tile sizes do not evenly divide the dimension; thus, only check
5033  // that the tiles and the tiled inner dimensions are constant.
5034  if (!areTilesAndTiledDimsAllConstant(*this))
5035  return Speculation::NotSpeculatable;
5036 
5037  return Speculation::Speculatable;
5038 }
5039 
5040 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5041 // dimensions for pack and unpack.
5042 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5043  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5044  return false;
5045  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5046  return true;
5047  // Outer dims permutation is optional.
5048  // To compare an unbalanced pack-unpack pair, treat a missing permutation as
5049  // the identity permutation.
5050  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5051  isIdentityPermutation(unPackOp.getOuterDimsPerm());
5052 }
5053 
5054 // Return true if pack and unpack have the same tiles.
5055 // Same SSA values or same integer constants.
5056 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5057  auto packTiles = packOp.getMixedTiles();
5058  auto unPackTiles = unPackOp.getMixedTiles();
5059  if (packTiles.size() != unPackTiles.size())
5060  return false;
5061  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5062  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5063  return false;
5064  }
5065  return true;
5066 }
5067 
5068 /// Returns true if the pack op does not need a padding value.
5069 static bool paddingIsNotNeeded(PackOp op) {
5070  auto srcType = op.getSourceType();
5071  if (llvm::any_of(op.getInnerDimsPos(),
5072  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
5073  return false;
5074  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5075  return false;
5076  return !PackOp::requirePaddingValue(
5077  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5078  op.getOuterDimsPerm(), op.getMixedTiles());
5079 }
5080 
5081 /// Returns true if the `srcShape` or `destShape` is different from the one in
5082 /// `packOp` and populates each with the inferred static shape.
5083 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
5084  SmallVectorImpl<int64_t> &destShape) {
5085  bool changeNeeded = false;
5086  srcShape.assign(packOp.getSourceType().getShape().begin(),
5087  packOp.getSourceType().getShape().end());
5088  destShape.assign(packOp.getDestType().getShape().begin(),
5089  packOp.getDestType().getShape().end());
5090  llvm::SmallSetVector<int64_t, 4> innerDims;
5091  innerDims.insert_range(packOp.getInnerDimsPos());
5092  SmallVector<int64_t> inverseOuterDimsPerm;
5093  if (!packOp.getOuterDimsPerm().empty())
5094  inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
5095  int srcRank = packOp.getSourceRank();
5096  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
5097  if (innerDims.contains(i))
5098  continue;
5099  int64_t srcPos = i;
5100  int64_t destPos = i;
5101  if (!inverseOuterDimsPerm.empty())
5102  destPos = inverseOuterDimsPerm[srcPos];
5103  if (ShapedType::isDynamic(srcShape[srcPos]) ==
5104  ShapedType::isDynamic(destShape[destPos])) {
5105  continue;
5106  }
5107  int64_t size = srcShape[srcPos];
5108  if (ShapedType::isDynamic(size))
5109  size = destShape[destPos];
5110  srcShape[srcPos] = size;
5111  destShape[destPos] = size;
5112  changeNeeded = true;
5113  }
5114  return changeNeeded;
5115 }
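// E.g. (illustrative): for a pack with inner_dims_pos = [1], source
// tensor<?x256xf32> and dest tensor<4x8x32xf32>, the untiled outer dim 0 is
// static (4) in the dest, so srcShape is refined to [4, 256].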
5116 
5117 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5118  // Fold pack(unpack(x)) to x.
5119  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5120  if (unPackOp.getSourceType() != packOp.getDestType())
5121  return failure();
5122  if (packOp.getPaddingValue() ||
5123  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5124  !haveSameTiles(packOp, unPackOp))
5125  return failure();
5126  rewriter.replaceOp(packOp, unPackOp.getSource());
5127  return success();
5128  }
5129 
5130  // Fold optional PaddingValue operand away if padding is not needed.
5131  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
5132  rewriter.startOpModification(packOp);
5133  packOp.getPaddingValueMutable().clear();
5134  rewriter.finalizeOpModification(packOp);
5135  return success();
5136  }
5137 
5138  // Insert tensor.cast ops if static shape inference is available.
5139  SmallVector<int64_t> srcShape, destShape;
5140  if (inferStaticShape(packOp, srcShape, destShape)) {
5141  Location loc = packOp.getLoc();
5142  Value source = packOp.getSource();
5143  if (srcShape != packOp.getSourceType().getShape()) {
5144  auto newSrcType = packOp.getSourceType().clone(srcShape);
5145  source =
5146  rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
5147  }
5148  Value dest = packOp.getDest();
5149  RankedTensorType originalResultType = packOp.getDestType();
5150  bool needUpdateDestType = (destShape != originalResultType.getShape());
5151  if (needUpdateDestType) {
5152  auto newDestType = packOp.getDestType().clone(destShape);
5153  dest =
5154  rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
5155  }
5156  rewriter.modifyOpInPlace(packOp, [&] {
5157  packOp.getSourceMutable().assign(source);
5158  packOp.getDestMutable().assign(dest);
5159  packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
5160  });
5161  // Insert a cast if needed
5162  if (needUpdateDestType) {
5163  rewriter.setInsertionPointAfter(packOp);
5164  auto castOp =
5165  rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
5166  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
5167  }
5168  return success();
5169  }
5170 
5171  return failure();
5172 }
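// The pack(unpack(x)) fold above, sketched in IR with illustrative shapes; it
// fires only when the tiles, inner_dims_pos, and outer_dims_perm all match and
// no padding value is present:
//
// ```mlir
// %u = linalg.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//      into %t0 : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
// %p = linalg.pack %u inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//      into %t1 : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
// // canonicalizes to just %x
// ```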
5173 
5174 template <typename PackOrUnpackOp>
5175 static bool isLikePadUnPad(PackOrUnpackOp packOp,
5176  RankedTensorType packedTensorType) {
5177  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5178  std::is_same<PackOrUnpackOp, UnPackOp>::value,
5179  "Function meant for pack/unpack");
5180  // This is a pad if packing only adds ones and we don't transpose dimensions.
5181 
5182  // Check that we are not transposing any dimensions.
5183  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5184  int64_t numPackedDims = innerDimsPos.size();
5185  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5186  if (orderedDims != innerDimsPos) {
5187  // Dimensions don't happen in order.
5188  return false;
5189  }
5190 
5191  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
5192  int64_t packedRank = packedTensorType.getRank();
5193  // At this point we know that we are taking numPackedDims outer
5194  // dimensions and pushing them all the way as the inner most dimensions.
5195  // What's left on the outer most dimensions is, in this order:
5196  // - the factor of the packed dimensions, then
5197  // - the untouched dimensions
5198  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
5199  // if all the dimensions that bubble outward are ones.
5200  // Therefore check that all the dimensions but the numPackedDims inner most
5201  // ones are ones.
5202  return llvm::all_of(
5203  llvm::seq<int64_t>(0, packedRank - numPackedDims),
5204  [&packedShape](int64_t i) { return packedShape[i] == 1; });
5205 }
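// E.g. (illustrative): packing tensor<8x32xf32> into tensor<1x1x8x32xf32> with
// inner_dims_pos = [0, 1] is pad-like: the inner dims are not transposed and
// all remaining outer dims are ones.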
5206 
5207 bool PackOp::isLikePad() {
5208  auto packedTensorType =
5209  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
5210  return isLikePadUnPad(*this, packedTensorType);
5211 }
5212 
5213 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
5214  std::optional<Attribute> paddingValue;
5215  if (auto pad = adaptor.getPaddingValue())
5216  paddingValue = pad;
5217  if (OpFoldResult reshapedSource = reshapeConstantSource(
5218  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5219  getDestType(), paddingValue))
5220  return reshapedSource;
5221  return {};
5222 }
5223 
5224 /// Folds a tensor.cast op into a consuming PackOp if the
5225 /// `tensor.cast` has a source that is more static than the consuming op.
5226 ///
5227 /// Example:
5228 /// ```mlir
5229 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5230 /// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
5231 /// ```
5232 ///
5233 /// folds into:
5234 ///
5235 /// ```mlir
5236 /// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
5237 /// ```
5238 struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5239  using OpRewritePattern<PackOp>::OpRewritePattern;
5240 
5241  LogicalResult matchAndRewrite(PackOp op,
5242  PatternRewriter &rewriter) const override {
5243  if (!tensor::hasFoldableTensorCastOperand(op))
5244  return failure();
5245 
5246  SmallVector<Type> newResultTypes(op->getResultTypes());
5247  SmallVector<Value> newOperands =
5248  tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5249 
5250  // Get the updated mixed-tile-sizes attribute.
5251  SmallVector<OpFoldResult> newMixedTileSizes =
5252  getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
5253 
5254  // Clone op.
5255  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5256  // this point. However, in practice, we use them for things that we'd like
5257  // to preserve. Implement a better abstraction.
5258  PackOp newOp = rewriter.create<PackOp>(
5259  op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
5260  newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
5261  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5262 
5263  // Replace op.
5264  Value oldResult = op.getResult();
5265  Value newResult = newOp.getResult();
5266  Value replacement = (newResult.getType() != oldResult.getType())
5267  ? rewriter.create<tensor::CastOp>(
5268  op->getLoc(), oldResult.getType(), newResult)
5269  : newResult;
5270 
5271  rewriter.replaceOp(op, {replacement});
5272 
5273  return success();
5274  }
5275 };
5276 
5277 //===----------------------------------------------------------------------===//
5278 // UnPackOp
5279 //===----------------------------------------------------------------------===//
5280 
5281 void UnPackOp::getAsmResultNames(
5282  function_ref<void(Value, StringRef)> setNameFn) {
5283  setNameFn(getResult(), "unpack");
5284 }
5285 
5286 LogicalResult
5287 UnPackOp::reifyResultShapes(OpBuilder &builder,
5288  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5289  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5290 }
5291 
5292 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
5293  return getDimAndTileMappingImpl(*this);
5294 }
5295 
5296 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
5297  return getMixedTilesImpl(*this);
5298 }
5299 
5300 SmallVector<int64_t> UnPackOp::getStaticTiles() {
5301  return getStaticTilesImpl(*this);
5302 }
5303 
5304 ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
5305  ShapedType destType = getDestType();
5306  int64_t destRank = destType.getRank();
5307  return getSourceType().getShape().take_front(destRank);
5308 }
5309 
5310 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
5311  auto innerDimsPos = getInnerDimsPos();
5312  auto packedShape = getSourceType().getShape();
5313  SmallVector<int64_t> res;
5314 
5315  for (auto index : innerDimsPos)
5316  res.push_back(packedShape[index]);
5317 
5318  return res;
5319 }
5320 
5321 LogicalResult UnPackOp::verify() {
5322  return commonVerifierPackAndUnPackOp(*this);
5323 }
5324 
5325 Speculation::Speculatability UnPackOp::getSpeculatability() {
5326  // See PackOp::getSpeculatability.
5327  if (!areTilesAndTiledDimsAllConstant(*this))
5328  return Speculation::NotSpeculatable;
5329 
5330  return Speculation::Speculatable;
5331 }
5332 
5333 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
5334  Value dest, ArrayRef<int64_t> innerDimsPos,
5335  ArrayRef<OpFoldResult> innerTiles,
5336  ArrayRef<int64_t> outerDimsPerm) {
5337  assert(innerDimsPos.size() == innerTiles.size() &&
5338  "number of tile sizes specified must match the specified number of "
5339  "original dimensions to be tiled");
5340  SmallVector<int64_t> staticTileSizes;
5341  SmallVector<Value> dynamicTileSizes;
5342  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5343  build(builder, state, dest.getType(), source, dest,
5344  outerDimsPerm.empty() ? nullptr
5345  : builder.getDenseI64ArrayAttr(outerDimsPerm),
5346  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5347  builder.getDenseI64ArrayAttr(staticTileSizes));
5348 }
5349 
5350 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
5351  Value source,
5352  ArrayRef<OpFoldResult> innerTileSizes,
5353  ArrayRef<int64_t> innerDimsPos,
5354  ArrayRef<int64_t> outerDimsPerm) {
5355  AffineExpr sym0, sym1;
5356  bindSymbols(b.getContext(), sym0, sym1);
5357  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5358  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
5359  };
5360 
5361  SmallVector<OpFoldResult> mixedSizes;
5362  auto srcType = llvm::cast<RankedTensorType>(source.getType());
5363  for (auto i :
5364  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5365  if (srcType.isDynamicDim(i))
5366  mixedSizes.push_back(b.create<tensor::DimOp>(loc, source, i).getResult());
5367  else
5368  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
5369  }
5370  if (!outerDimsPerm.empty()) {
5371  applyPermutationToVector<OpFoldResult>(
5372  mixedSizes, invertPermutationVector(outerDimsPerm));
5373  }
5374 
5375  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
5376  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5377 
5378  auto elemType = srcType.getElementType();
5379  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
5380 }
5381 
5382 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
5383  Value transposedSource,
5384  ArrayRef<int64_t> innerPermutation,
5385  ArrayRef<int64_t> outerPermutation) {
5386  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5387  *this, innerPermutation, outerPermutation);
5388  return b.create<UnPackOp>(loc, transposedSource, getDest(),
5389  metadata.innerDimsPos, metadata.innerTiles,
5390  metadata.outerDimsPerm);
5391 }
5392 
5393 /// Returns true if the `srcShape` or `destShape` is different from the one in
5394 /// `op` and populates each with the inferred static shape.
5395 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
5396  SmallVectorImpl<int64_t> &destShape) {
5397  bool changeNeeded = false;
5398  srcShape.assign(op.getSourceType().getShape().begin(),
5399  op.getSourceType().getShape().end());
5400  destShape.assign(op.getDestType().getShape().begin(),
5401  op.getDestType().getShape().end());
5402  llvm::SmallSetVector<int64_t, 4> innerDims;
5403  innerDims.insert_range(op.getInnerDimsPos());
5404  SmallVector<int64_t> inverseOuterDimsPerm;
5405  if (!op.getOuterDimsPerm().empty())
5406  inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
5407  int destRank = op.getDestRank();
5408  for (auto i : llvm::seq<int64_t>(0, destRank)) {
5409  if (innerDims.contains(i))
5410  continue;
5411  int64_t srcPos = i;
5412  int64_t destPos = i;
5413  if (!inverseOuterDimsPerm.empty())
5414  srcPos = inverseOuterDimsPerm[destPos];
5415  if (ShapedType::isDynamic(srcShape[srcPos]) ==
5416  ShapedType::isDynamic(destShape[destPos])) {
5417  continue;
5418  }
5419  int64_t size = srcShape[srcPos];
5420  if (ShapedType::isDynamic(size))
5421  size = destShape[destPos];
5422  srcShape[srcPos] = size;
5423  destShape[destPos] = size;
5424  changeNeeded = true;
5425  }
5426  return changeNeeded;
5427 }
5428 
5429 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5430  PatternRewriter &rewriter) {
5431  /// unpack(pack(x)) -> x
5432  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
5433  if (packOp.getSourceType() != unPackOp.getDestType())
5434  return failure();
5435  if (packOp.getPaddingValue() ||
5436  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5437  !haveSameTiles(packOp, unPackOp))
5438  return failure();
5439  rewriter.replaceOp(unPackOp, packOp.getSource());
5440  return success();
5441  }
5442  /// unpack(destinationStyleOp(x)) -> unpack(x)
5443  if (auto dstStyleOp =
5444  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5445  auto destValue = cast<OpResult>(unPackOp.getDest());
5446  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5447  rewriter.modifyOpInPlace(unPackOp,
5448  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5449  return success();
5450  }
5451  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
5452  if (unPackOp->hasOneUse()) {
5453  auto extractSliceUser =
5454  dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
5455  if (extractSliceUser &&
5456  areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
5457  areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
5458  extractSliceUser.getSourceType().getRank() ==
5459  extractSliceUser.getResultType().getRank()) {
5460  OpBuilder::InsertionGuard g(rewriter);
5461  rewriter.setInsertionPoint(unPackOp);
5462  auto newDest = rewriter.create<tensor::ExtractSliceOp>(
5463  unPackOp->getLoc(), unPackOp.getDest(),
5464  extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5465  extractSliceUser.getMixedStrides());
5466  rewriter.modifyOpInPlace(unPackOp, [&]() {
5467  unPackOp.setDpsInitOperand(0, newDest);
5468  unPackOp.getResult().setType(newDest.getType());
5469  });
5470  rewriter.replaceOp(extractSliceUser, unPackOp);
5471  return success();
5472  }
5473  }
5474 
5475  // Insert tensor.cast ops if static shape inference is available.
5476  SmallVector<int64_t> srcShape, destShape;
5477  if (inferStaticShape(unPackOp, srcShape, destShape)) {
5478  Location loc = unPackOp.getLoc();
5479  Value source = unPackOp.getSource();
5480  if (srcShape != unPackOp.getSourceType().getShape()) {
5481  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5482  source = rewriter.create<tensor::CastOp>(loc, newSrcType,
5483  unPackOp.getSource());
5484  }
5485  Value dest = unPackOp.getDest();
5486  if (destShape != unPackOp.getDestType().getShape()) {
5487  auto newDestType = unPackOp.getDestType().clone(destShape);
5488  dest =
5489  rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
5490  }
5491  Value newOp = rewriter.create<UnPackOp>(
5492  loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
5493  unPackOp.getOuterDimsPerm());
5494  rewriter.replaceOpWithNewOp<tensor::CastOp>(
5495  unPackOp, unPackOp.getResult().getType(), newOp);
5496  return success();
5497  }
5498 
5499  return failure();
5500 }
5501 
5502 bool UnPackOp::isLikeUnPad() {
5503  RankedTensorType packedTensorType = getSourceType();
5504  return isLikePadUnPad(*this, packedTensorType);
5505 }
5506 
5507 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
5508  if (OpFoldResult reshapedSource = reshapeConstantSource(
5509  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5510  getResult().getType()))
5511  return reshapedSource;
5512  return {};
5513 }
5514 
5515 /// Folds a tensor.cast op into a consuming UnPackOp if the
5516 /// `tensor.cast` has a source that is more static than the consuming op.
5517 ///
5518 /// Example:
5519 /// ```mlir
5520 /// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
5521 /// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
5522 /// ```
5523 ///
5524 /// folds into:
5525 ///
5526 /// ```mlir
5527 /// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
5528 /// ```
5529 struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
5530  using OpRewritePattern<UnPackOp>::OpRewritePattern;
5531 
5532  LogicalResult matchAndRewrite(UnPackOp op,
5533  PatternRewriter &rewriter) const override {
5534  if (!tensor::hasFoldableTensorCastOperand(op))
5535  return failure();
5536 
5537  SmallVector<Type> newResultTypes(op->getResultTypes());
5538  SmallVector<Value> newOperands =
5539  tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5540  Value sourceTensor = newOperands[0];
5541 
5542  // Get the updated mixed-tile-sizes attribute.
5543  SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
5544  rewriter, sourceTensor.getType(), op.getMixedTiles());
5545 
5546  // Clone op.
5547  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5548  // this point. However, in practice, we use them for things that we'd like
5549  // to preserve. Implement a better abstraction.
5550  UnPackOp newOp = rewriter.create<UnPackOp>(
5551  op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
5552  newMixedTileSizes, op.getOuterDimsPerm());
5553  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5554 
5555  // Replace op.
5556  Value oldResult = op.getResult();
5557  Value newResult = newOp.getResult();
5558  Value replacement = (newResult.getType() != oldResult.getType())
5559  ? rewriter.create<tensor::CastOp>(
5560  op->getLoc(), oldResult.getType(), newResult)
5561  : newResult;
5562 
5563  rewriter.replaceOp(op, {replacement});
5564 
5565  return success();
5566  }
5567 };
5568 
5569 //===----------------------------------------------------------------------===//
5570 // BatchReduceMatmulOp
5571 //===----------------------------------------------------------------------===//
5572 SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
5573  return SmallVector<utils::IteratorType>{
5574  utils::IteratorType::reduction, utils::IteratorType::parallel,
5575  utils::IteratorType::parallel, utils::IteratorType::reduction};
5576 }
5577 
5578 SmallVector<AffineMap>
5579 BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
5580  AffineExpr d0, d1, d2, d3;
5581  SmallVector<AffineMap> indexingMaps;
5582  bindDims(context, d0, d1, d2, d3);
5583  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
5584  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
5585  indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
5586  return indexingMaps;
5587 }
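// Written as affine maps, with (d0, d1, d2, d3) = (batch, m, n, k):
//   A: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
//   B: affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
//   C: affine_map<(d0, d1, d2, d3) -> (d1, d2)>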
5588 
5589 unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
5590 
5591 std::string BatchReduceMatmulOp::getLibraryCallName() {
5592  return generateLibraryCallName(getOperation());
5593 }
5594 
5595 /// Check if the op has broadcast and/or transpose semantics. Returns true if
5596 /// the user-defined indexing maps are not equal to the default maps.
5597 bool BatchReduceMatmulOp::hasUserDefinedMaps() {
5598  SmallVector<AffineMap, 3> defaultMaps =
5599  getDefaultIndexingMaps(this->getContext());
5600  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
5601  return defaultMaps != explicitMaps;
5602 }
5603 
5604 /// Returns true if the given bcastMap is a valid broadcast map. A valid
5605 /// broadcast map must include the K dimension.
5606 /// TODO: Strict inclusion of K dimension in the broadcast map is not
5607 /// necessary for both input matrices simultaneously. We can relax this
5608 /// condition to have K dimension for one input matrix map and infer the K
5609 /// dimension for other input matrix map from the one already having K
5610 /// dimension.
5611 bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
5612  bool isLHS) {
5613  assert(bcastMap.getNumResults() < 3 &&
5614  "Expected less than 3 result dim expr.");
5615  bool isValid = false;
5616  enum Indices { batchPos, mPos, nPos, kPos };
5617  if (bcastMap.getNumResults() == 1) {
5618  AffineExpr expr = bcastMap.getResult(0);
5619  isValid = expr.isFunctionOfDim(kPos);
5620  } else if (bcastMap.getNumResults() == 2) {
5621  AffineExpr expr0 = bcastMap.getResult(0);
5622  AffineExpr expr1 = bcastMap.getResult(1);
5623  isValid =
5624  isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
5625  expr0.isFunctionOfDim(mPos)) &&
5626  expr1.isFunctionOfDim(kPos))
5627  : ((expr0.isFunctionOfDim(batchPos) &&
5628  expr1.isFunctionOfDim(kPos)) ||
5629  (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
5630  }
5631  return isValid;
5632 }
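// E.g. (illustrative): a single-result map such as
// affine_map<(batch, m, n, k) -> (k)> is accepted for either operand, while a
// two-result RHS map must take the form (batch, k) or (k, n).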
5633 
5634 void BatchReduceMatmulOp::regionBuilder(
5635  ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
5636  function_ref<InFlightDiagnostic()> emitError) {
5637  if (emitError && block.getNumArguments() != 3) {
5638  emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
5639  << block.getNumArguments();
5640  return;
5641  }
5642  assert(block.getNumArguments() == 3 &&
5643  "BatchReduceMatmulOp regionBuilder expects 3 args");
5644  RegionBuilderHelper helper(b, block);
5645  SmallVector<Value> yields;
5646 
5647  auto toType = block.getArgument(2).getType();
5648  Value castValA =
5649  helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
5650  Value castValB =
5651  helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
5652  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
5653  Value addVal =
5654  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
5655  yields.push_back(addVal);
5656  helper.yieldOutputs(yields);
5657 }
5658 
5659 ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
5660  OperationState &result) {
5661  SmallVector<Attribute, 3> indexingMapsAttr;
5662  Attribute mapAttr;
5663  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
5664  if (parser.parseEqual())
5665  return failure();
5666  if (parser.parseLSquare())
5667  return failure();
5668 
5669  do {
5670  if (parser.parseAttribute(mapAttr))
5671  return failure();
5672  if (!isa<AffineMapAttr>(mapAttr)) {
5673  return parser.emitError(parser.getCurrentLocation(),
5674  "expected affine map attribute");
5675  }
5676  indexingMapsAttr.push_back(mapAttr);
5677 
5678  if (parser.parseOptionalComma())
5679  break;
5680  } while (true);
5681 
5682  if (parser.parseRSquare())
5683  return failure();
5684  }
5685  // Initialize indexingMaps, if not supplied explicitly.
5686  if (indexingMapsAttr.empty()) {
5687  indexingMapsAttr = llvm::map_to_vector(
5688  BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
5689  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
5690  }
5691  result.addAttribute("indexing_maps",
5692  parser.getBuilder().getArrayAttr(indexingMapsAttr));
5693  return ::parseNamedStructuredOp(parser, result,
5694  BatchReduceMatmulOp::getNumRegionArgs(),
5695  BatchReduceMatmulOp::getRegionBuilder());
5696 }
5697 
5698 void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
5699  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
5700  BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
5701  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
5702 
5703  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
5704  p << " indexing_maps = [";
5705  llvm::interleaveComma(getIndexingMaps(), p,
5706  [&](Attribute attr) { p.printAttribute(attr); });
5707  p << "]";
5708  }
5709 
5710  SmallVector<StringRef, 3> elidedAttrs = {
5711  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
5712  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
5713  elidedAttrs);
5714 }
5715 
5716 /// Verify the user-defined indexing maps.
5717 LogicalResult BatchReduceMatmulOp::verify() {
5718  // Verification of pure batch_reduce_matmul is handled by
5719  // verifyStructuredOpInterface().
5720  if (!hasUserDefinedMaps())
5721  return success();
5722 
5723  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
5724  if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
5725  return failure();
5726  }
5727  return success();
5728 }
5729 LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
5730  SmallVectorImpl<OpFoldResult> &) {
5731  return memref::foldMemRefCast(*this);
5732 }
5733 void BatchReduceMatmulOp::getEffects(
5734  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5735  &effects) {
5736  if (hasPureTensorSemantics())
5737  return;
5738  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
5739 }
5740 
5741 Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
5742  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
5743 }
5744 
5745 } // namespace linalg
5746 } // namespace mlir
5747 
5748 //===----------------------------------------------------------------------===//
5749 // LinalgDialect
5750 //===----------------------------------------------------------------------===//
5751 
5752 void LinalgDialect::getCanonicalizationPatterns(
5753  RewritePatternSet &results) const {
5754  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
5755  FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
5756 }
5757 
5758 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
5759  Attribute value, Type type,
5760  Location loc) {
5761  return arith::ConstantOp::materialize(builder, value, type, loc);
5762 }
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1609
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:395
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
Definition: LinalgOps.cpp:3671
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
Definition: LinalgOps.cpp:402
BinaryFn binaryFn
Definition: LinalgOps.cpp:4275
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:246
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
Definition: LinalgOps.cpp:332
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:316
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:611
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:260
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:162
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:382
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:107
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:359
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:24
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
IndexType getIndexType()
Definition: Builders.cpp:50
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:313
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:729
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
This is a value defined by a result of an operation.
Definition: Value.h:447
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:413
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_iterator result_end()
Definition: Operation.h:414
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
Definition: Operation.h:523
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:810
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:665
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:593
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:577
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1280
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1333
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2691
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Definition: LinalgOps.cpp:4538
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
Definition: LinalgOps.cpp:5083
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
Definition: LinalgOps.cpp:5175
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
Definition: LinalgOps.cpp:4876
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
Definition: LinalgOps.cpp:4594
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
Definition: LinalgOps.cpp:4609
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
Definition: LinalgOps.cpp:4581
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
Definition: LinalgOps.cpp:4891
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2429
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
Definition: LinalgOps.cpp:4285
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
Definition: LinalgOps.cpp:4721
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:108
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
Definition: LinalgOps.cpp:2470
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
Definition: LinalgOps.cpp:5069
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2409
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:2420
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
Definition: LinalgOps.cpp:4550
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
Definition: LinalgOps.cpp:4506
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Definition: LinalgOps.cpp:5056
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:99
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
Definition: LinalgOps.cpp:5042
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
Definition: LinalgOps.cpp:5011
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
Definition: LinalgOps.cpp:3711
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
Definition: LinalgOps.cpp:4565
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
Definition: LinalgOps.cpp:3701
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
Definition: LinalgOps.cpp:4624
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
Definition: LinalgOps.cpp:3813
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
Definition: TensorOps.cpp:363
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:356
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
Definition: TensorOps.cpp:372
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:73
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:784
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:1286
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Fold transpose with transpose.
Definition: LinalgOps.cpp:2090
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:2093
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
Definition: LinalgOps.cpp:2115
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:2118
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:330
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
Definition: LinalgOps.cpp:5238
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:5241
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
Definition: LinalgOps.cpp:5529
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:5532